From 301c60df37170a09b7914a3e5ec7a60c229f5864 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Thu, 29 May 2025 01:03:34 +0200 Subject: [PATCH 01/44] writing new IR --- Python/pyproject.toml | 2 +- TensorFrost/IR/include/IR.h | 114 ++++++++++++++++++++++++++++ TensorFrost/IR/include/Operations.h | 22 ++++++ TensorFrost/IR/include/Overloads.h | 64 ++++++++++++++++ TensorFrost/IR/src/IR.cpp | 0 TensorFrost/IR/src/Overloads.cpp | 0 6 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 TensorFrost/IR/include/IR.h create mode 100644 TensorFrost/IR/include/Operations.h create mode 100644 TensorFrost/IR/include/Overloads.h create mode 100644 TensorFrost/IR/src/IR.cpp create mode 100644 TensorFrost/IR/src/Overloads.cpp diff --git a/Python/pyproject.toml b/Python/pyproject.toml index 39eb00ff..2d9c556b 100644 --- a/Python/pyproject.toml +++ b/Python/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "TensorFrost" -version = "0.7.4" +version = "0.8.0.dev0" description = "A static optimizing tensor compiler with a Python frontend" authors = [{name = "Mykhailo Moroz", email = "michael08840884@gmail.com"}] requires-python = ">=3.7" diff --git a/TensorFrost/IR/include/IR.h b/TensorFrost/IR/include/IR.h new file mode 100644 index 00000000..9f80eb11 --- /dev/null +++ b/TensorFrost/IR/include/IR.h @@ -0,0 +1,114 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace TensorFrost { + +using uint = unsigned int; + +struct Type { + std::string dtype; + size_t size = 0; // size in bytes +}; + +struct Value; +struct Op; +struct Arguments; + +using Attribute = std::variant; +using AttributeMap = std::unordered_map; + +struct Value { + int id; + Type type; + Op* producer = nullptr; +}; + +struct Block { + std::list> ops; + Op* append(std::unique_ptr op) { + ops.emplace_back(std::move(op)); + return ops.back().get(); + } +}; + +struct ExecutionContext { + std::unique_ptr base_block; + Block* current_block = 
base_block.get(); + std::vector stack; + + void BeginBlock(Op* op) { + stack.push_back(current_block); + current_block = new Block(); + } + + void EndBlock() { + if (!stack.empty()) { + current_block = stack.back(); + stack.pop_back(); + } else { + current_block = nullptr; + } + } + + void AddOp(std::unique_ptr op) { + if (!current_block) { + throw std::runtime_error("No current block to add operation to"); + } + current_block->append(std::move(op)); + } +}; + +enum class ArgType { + Input, + Output, + Memory, + Shape, + Count +}; + +struct Argument { + ArgType type; + Value* from_value = nullptr; + Op* target_op = nullptr; + int index = 0; +}; + +struct Op { + static ExecutionContext* current_context; + std::string opcode; + std::vector arguments; + std::vector outputs; + std::vector> blocks; + AttributeMap attributes; + + Op(std::string op_name) : opcode(std::move(op_name)) { + if (!current_context) { + throw std::runtime_error("No current execution context set for operation creation"); + } + current_context->AddOp(std::make_unique(*this)); + } + + std::string operator std::string() const { + return opcode; + } + + Op& binary(Op& other, const std::string& op_name) { + Op* new_op = new Op(op_name); + Argument arg1{ArgType::Input, &outputs[0], this, 0}; + Argument arg2{ArgType::Input, &other.outputs[0], &other, 0}; + new_op->arguments.push_back(arg1); + new_op->arguments.push_back(arg2); + new_op->outputs.push_back(Value{0, Type{"float", 4}, new_op}); + return *new_op; + } +}; + + + + +} // namespace ir + diff --git a/TensorFrost/IR/include/Operations.h b/TensorFrost/IR/include/Operations.h new file mode 100644 index 00000000..bab643d3 --- /dev/null +++ b/TensorFrost/IR/include/Operations.h @@ -0,0 +1,22 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace TensorFrost { + +struct OpSpec { + std::string name; + int arity = -1; // -1 variadic +}; + +inline std::unordered_map& registry() { + static std::unordered_map r; + return r; +} + 
+inline void reg(OpSpec s) { registry()[s.name] = std::move(s); } + +} // namespace ir \ No newline at end of file diff --git a/TensorFrost/IR/include/Overloads.h b/TensorFrost/IR/include/Overloads.h new file mode 100644 index 00000000..5d6fb342 --- /dev/null +++ b/TensorFrost/IR/include/Overloads.h @@ -0,0 +1,64 @@ +#pragma once +#include "IR.h" + +namespace TensorFrost { +using ArgValue = std::variant; + +// Arithmetic operations +Value& operator+(ArgValue& a, ArgValue& b); +Value& operator-(ArgValue& a, ArgValue& b); +Value& operator*(ArgValue& a, ArgValue& b); +Value& operator/(ArgValue& a, ArgValue& b); +Value& operator%(ArgValue& a, ArgValue& b); + +// Bitwise operations +Value& operator&(ArgValue& a, ArgValue& b); +Value& operator|(ArgValue& a, ArgValue& b); +Value& operator^(ArgValue& a, ArgValue& b); +Value& operator<<(ArgValue& a, ArgValue& b); +Value& operator>>(ArgValue& a, ArgValue& b); +Value& operator~(ArgValue& a); + +// Comparison operations +Value& operator==(ArgValue& a, ArgValue& b); +Value& operator!=(ArgValue& a, ArgValue& b); +Value& operator<(ArgValue& a, ArgValue& b); +Value& operator<=(ArgValue& a, ArgValue& b); +Value& operator>(ArgValue& a, ArgValue& b); +Value& operator>=(ArgValue& a, ArgValue& b); + +// Logical operations +Value& operator&&(ArgValue& a, ArgValue& b); +Value& operator||(ArgValue& a, ArgValue& b); +Value& operator!(ArgValue& a); + +// Increment and decrement operations +Value& operator++(ArgValue& a); +Value& operator--(ArgValue& a); + +// Mathematical functions +Value& abs(ArgValue& a); +Value& sqrt(ArgValue& a); +Value& pow(ArgValue& base, ArgValue& exponent); +Value& sin(ArgValue& a); +Value& cos(ArgValue& a); +Value& tan(ArgValue& a); +Value& log(ArgValue& a); +Value& exp(ArgValue& a); +Value& min(ArgValue& a, ArgValue& b); +Value& max(ArgValue& a, ArgValue& b); +Value& clamp(ArgValue& value, ArgValue& min_val, ArgValue& max_val); +Value& round(ArgValue& a); +Value& floor(ArgValue& a); +Value& ceil(ArgValue& a); 
+Value& select(ArgValue& condition, ArgValue& true_value, ArgValue& false_value); + + + + + + + + + +} diff --git a/TensorFrost/IR/src/IR.cpp b/TensorFrost/IR/src/IR.cpp new file mode 100644 index 00000000..e69de29b diff --git a/TensorFrost/IR/src/Overloads.cpp b/TensorFrost/IR/src/Overloads.cpp new file mode 100644 index 00000000..e69de29b From 09102c6950260f3a1e19a5a2e6c3cac2e5571362 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Sun, 1 Jun 2025 04:24:06 +0200 Subject: [PATCH 02/44] Working on new IR --- ProtoIR.txt | 97 ++++++++ TensorFrost/IR/include/Common.h | 109 ++++++++ TensorFrost/IR/include/ExecutionContext.h | 24 ++ TensorFrost/IR/include/IR.h | 114 --------- TensorFrost/IR/include/Operation.h | 23 ++ TensorFrost/IR/include/OperationArguments.h | 42 ++++ TensorFrost/IR/include/OperationBlocks.h | 36 +++ TensorFrost/IR/include/OperationRegistry.h | 20 ++ TensorFrost/IR/include/Operations.h | 22 -- TensorFrost/IR/include/Overloads.h | 262 ++++++++++++++++---- TensorFrost/IR/src/Common.cpp | 24 ++ TensorFrost/IR/src/ExecutionContext.cpp | 47 ++++ TensorFrost/IR/src/IR.cpp | 0 TensorFrost/IR/src/Operation.cpp | 30 +++ TensorFrost/IR/src/OperationArguments.cpp | 44 ++++ TensorFrost/IR/src/OperationBlocks.cpp | 83 +++++++ TensorFrost/IR/src/OperationRegistry.cpp | 90 +++++++ TensorFrost/IR/src/Overloads.cpp | 149 +++++++++++ 18 files changed, 1035 insertions(+), 181 deletions(-) create mode 100644 ProtoIR.txt create mode 100644 TensorFrost/IR/include/Common.h create mode 100644 TensorFrost/IR/include/ExecutionContext.h delete mode 100644 TensorFrost/IR/include/IR.h create mode 100644 TensorFrost/IR/include/Operation.h create mode 100644 TensorFrost/IR/include/OperationArguments.h create mode 100644 TensorFrost/IR/include/OperationBlocks.h create mode 100644 TensorFrost/IR/include/OperationRegistry.h delete mode 100644 TensorFrost/IR/include/Operations.h create mode 100644 TensorFrost/IR/src/Common.cpp create mode 100644 
TensorFrost/IR/src/ExecutionContext.cpp delete mode 100644 TensorFrost/IR/src/IR.cpp create mode 100644 TensorFrost/IR/src/Operation.cpp create mode 100644 TensorFrost/IR/src/OperationArguments.cpp create mode 100644 TensorFrost/IR/src/OperationBlocks.cpp create mode 100644 TensorFrost/IR/src/OperationRegistry.cpp diff --git a/ProtoIR.txt b/ProtoIR.txt new file mode 100644 index 00000000..b48dbdf4 --- /dev/null +++ b/ProtoIR.txt @@ -0,0 +1,97 @@ +Node = [name, arguments, attributes] + + +n = input_dim(attributes{type=int32, input_index=0, dim_index=0}) +a = input(args{shape=[n]}, attributes{type=float32, input_index=0}) +b = sin(args{input=[a], shape=[n]}, attributes{type=float32}) +c = load(args{input=[b], indices=[0]}, attributes{type=float32}) +d = average(args{input=[b]}, attributes{type=float32}) +res = div(args{input=[d,c]}, attributes{type=float32}) +ids = parallel(args{shape=[n, n]}, attributes{type=tuple}) { + i = unpack_tuple(args{input=[ids,0]}, attributes{type=int32}) + j = unpack_tuple(args{input=[ids,0]}, attributes{type=int32}) + a0 = load(args{input=[b], indices=[i]}, attributes{type=float32}) + a1 = load(args{input=[b], indices=[j]}, attributes{type=float32}) + outer = mul(args{input=[a0, a1]}, attributes{type=float32, output_index=0}) +} + + +vmap (shape = [a, b, c]) { + A_2 = load(memory=[A], indices=[i,j]) + int(32) v2_0(32) = const(outputs=[v2_1, v2_2, ], data=[1], index=15, debug_index=32, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + float(32) v2_1(32) = dim_norm(inputs=[A_2], outputs=[v2_2, ], data=[0], shape=[v2_0(1)], index=16, debug_index=34, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + store(memory=[R(0.f)], inputs=[v2_1], indices=[i,i], shape=[v2_0(1)], index=17, debug_index=36, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + +} + +int(32) v1_0(32) = const(data=[4294967295], index=1, debug_index=4, debug_name=, created_in=Tracing initial graph, 
created_in_func=None, ) +int(32) v1_1(32) = const(data=[4294967295], index=2, debug_index=6, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) n(32) = input_shape(outputs=[v1_3, v2_8, v3_5, v3_1, v3_12, v3_7, v3_10, v3_17, v3_14, A, Q(0.f), R(0.f), R(0.f), ], flags={InputShapeDim(1), }, index=3, debug_index=8, debug_name=n, created_in=Tracing initial graph, created_in_func=None, ) +int(32) m(32) = input_shape(outputs=[v2_9, A_2, j, A_3, v2_4, v2_3, A_6, A_7, v3_15, v3_18, A, Q(0.f), ], flags={InputShapeDim(0), }, index=4, debug_index=10, debug_name=m, created_in=Tracing initial graph, created_in_func=None, ) +float(32) A(32) = memory(outputs=[A_2, A_3, A_5, v2_20, A_4, A_6, A_7, ], flags={Modified, InputMemory(0), }, shape=[n,m], index=5, debug_index=12, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) +float(32) Q(32) = const(outputs=[Q_2, v2_4, Q_3, v3_18, ], data=[0], flags={Modified, OutputMemory(0), }, shape=[n,m], index=6, debug_index=14, debug_name=Q, created_in=Tracing initial graph, created_in_func=None, ) +float(32) R(32) = const(outputs=[v2_2, R_2, v2_17, R_3, v3_8, R_4, ], data=[0], flags={Modified, OutputMemory(1), }, shape=[n,n], index=7, debug_index=16, debug_name=R, created_in=Tracing initial graph, created_in_func=None, ) +int(32) j(32) = dim_id(outputs=[A_2, A_3, v2_4, A_6, A_7, v3_18, ], data=[0], shape=[m], index=8, debug_index=18, debug_name=j, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v1_2(32) = const(outputs=[v1_3, ], data=[1], index=9, debug_index=20, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v1_3(32) = sub(inputs=[n,v1_2(1)], outputs=[i, ], index=10, debug_index=22, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v1_4(32) = const(outputs=[i, ], data=[1], index=11, debug_index=24, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v1_5(32) = const(outputs=[i, ], 
data=[0], index=12, debug_index=26, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) i(32) = loop(inputs=[v1_5(0),v1_3,v1_4(1)], outputs=[v2_6, v2_14, A_2, v2_2, Q_2, R_2, A_3, v2_4, Q_3, v2_2, R_2, v2_17, R_3, ], index=13, debug_index=28, debug_name=i, created_in=Tracing initial graph, created_in_func=None, ) +{ + float(32) A_2(32) = load(memory=[A], indices=[i,j], outputs=[v2_1, ], data=[0], shape=[m], index=14, debug_index=29, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_0(32) = const(outputs=[v2_1, v2_2, ], data=[1], index=15, debug_index=32, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + float(32) v2_1(32) = dim_norm(inputs=[A_2], outputs=[v2_2, ], data=[0], shape=[v2_0(1)], index=16, debug_index=34, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + store(memory=[R(0.f)], inputs=[v2_1], indices=[i,i], shape=[v2_0(1)], index=17, debug_index=36, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + float(32) A_3(32) = load(memory=[A], indices=[i,j], outputs=[v2_3, ], data=[0], shape=[m], index=18, debug_index=38, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) + float(32) R_2(32) = load(memory=[R(0.f)], indices=[i,i], outputs=[v2_3, ], data=[0], index=19, debug_index=40, debug_name=R, created_in=Tracing initial graph, created_in_func=None, ) + float(32) v2_3(32) = div(inputs=[A_3,R_2], outputs=[v2_4, ], shape=[m], index=20, debug_index=42, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + store(memory=[Q(0.f)], inputs=[v2_3], indices=[i,j], shape=[m], index=21, debug_index=44, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_5(32) = const(outputs=[v2_6, ], data=[1], index=22, debug_index=46, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_6(32) = add(inputs=[i,v2_5(1)], outputs=[v2_8, k, ], index=23, 
debug_index=48, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_7(32) = const(outputs=[p, v2_9, ], data=[0], index=24, debug_index=50, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_8(32) = sub(inputs=[n,v2_6], outputs=[p, Q_2, v2_10, v2_11, k, v2_17, Q_3, v2_13, v2_16, A_5, v2_12, v2_18, v2_14, v2_19, v2_20, R_3, A_4, dot, ], index=25, debug_index=52, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_9(32) = sub(inputs=[m,v2_7(0)], outputs=[p, Q_2, v2_10, v2_11, k, Q_3, A_5, v2_12, v2_18, v2_19, v2_20, R_3, A_4, ], index=26, debug_index=54, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_10(32) = dim_id(outputs=[k, ], data=[0], shape=[v2_8,v2_9], index=27, debug_index=56, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) k(32) = add(inputs=[v2_10,v2_6], outputs=[A_5, v2_20, R_3, A_4, ], shape=[v2_8,v2_9], index=28, debug_index=58, debug_name=k, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_11(32) = dim_id(outputs=[p, ], data=[1], shape=[v2_8,v2_9], index=29, debug_index=60, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) p(32) = add(inputs=[v2_11,v2_7(0)], outputs=[Q_2, Q_3, A_5, v2_20, A_4, ], shape=[v2_8,v2_9], index=30, debug_index=62, debug_name=p, created_in=Tracing initial graph, created_in_func=None, ) + float(32) Q_2(32) = load(memory=[Q(0.f)], indices=[i,p], outputs=[v2_12, ], data=[0], shape=[v2_8,v2_9], index=31, debug_index=64, debug_name=Q, created_in=Tracing initial graph, created_in_func=None, ) + float(32) A_4(32) = load(memory=[A], indices=[k,p], outputs=[v2_12, ], data=[0], shape=[v2_8,v2_9], index=32, debug_index=66, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) + float(32) v2_12(32) = mul(inputs=[Q_2,A_4], outputs=[dot, ], shape=[v2_8,v2_9], index=33, debug_index=68, debug_name=, 
created_in=Tracing initial graph, created_in_func=None, ) + float(32) dot(32) = dim_sum(inputs=[v2_12], outputs=[v2_17, ], data=[1], shape=[v2_8], index=34, debug_index=70, debug_name=dot, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_13(32) = dim_id(outputs=[v2_14, ], data=[0], shape=[v2_8], index=35, debug_index=72, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_14(32) = add(inputs=[v2_13,i], outputs=[v2_16, ], shape=[v2_8], index=36, debug_index=74, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_15(32) = const(outputs=[v2_16, ], data=[1], index=37, debug_index=76, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + int(32) v2_16(32) = add(inputs=[v2_14,v2_15(1)], outputs=[v2_17, ], shape=[v2_8], index=38, debug_index=78, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + store(memory=[R(0.f)], inputs=[dot], indices=[v2_16,i], shape=[v2_8], index=39, debug_index=80, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + float(32) A_5(32) = load(memory=[A], indices=[k,p], outputs=[v2_19, ], data=[0], shape=[v2_8,v2_9], index=40, debug_index=82, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) + float(32) Q_3(32) = load(memory=[Q(0.f)], indices=[i,p], outputs=[v2_18, ], data=[0], shape=[v2_8,v2_9], index=41, debug_index=84, debug_name=Q, created_in=Tracing initial graph, created_in_func=None, ) + float(32) R_3(32) = load(memory=[R(0.f)], indices=[k,i], outputs=[v2_18, ], data=[0], shape=[v2_8,v2_9], index=42, debug_index=86, debug_name=R, created_in=Tracing initial graph, created_in_func=None, ) + float(32) v2_18(32) = mul(inputs=[Q_3,R_3], outputs=[v2_19, ], shape=[v2_8,v2_9], index=43, debug_index=88, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + float(32) v2_19(32) = sub(inputs=[A_5,v2_18], outputs=[v2_20, ], shape=[v2_8,v2_9], index=44, debug_index=90, 
debug_name=, created_in=Tracing initial graph, created_in_func=None, ) + store(memory=[A], inputs=[v2_19], indices=[k,p], shape=[v2_8,v2_9], index=45, debug_index=92, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +} +int(32) v3_0(32) = const(outputs=[v3_1, ], data=[1], index=46, debug_index=30, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_1(32) = sub(inputs=[n,v3_0(1)], outputs=[A_6, ], index=47, debug_index=96, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +float(32) A_6(32) = load(memory=[A], indices=[v3_1,j], outputs=[v3_3, ], data=[0], shape=[m], index=48, debug_index=98, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_2(32) = const(outputs=[v3_8, v3_3, ], data=[1], index=49, debug_index=100, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +float(32) v3_3(32) = dim_norm(inputs=[A_6], outputs=[v3_8, ], data=[0], shape=[v3_2(1)], index=50, debug_index=102, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_4(32) = const(outputs=[v3_5, ], data=[1], index=51, debug_index=104, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_5(32) = sub(inputs=[n,v3_4(1)], outputs=[v3_8, ], index=52, debug_index=106, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_6(32) = const(outputs=[v3_7, ], data=[1], index=53, debug_index=108, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_7(32) = sub(inputs=[n,v3_6(1)], outputs=[v3_8, ], index=54, debug_index=110, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +store(memory=[R(0.f)], inputs=[v3_3], indices=[v3_7,v3_5], shape=[v3_2(1)], index=55, debug_index=112, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_9(32) = const(outputs=[v3_10, ], data=[1], index=56, debug_index=114, debug_name=, 
created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_10(32) = sub(inputs=[n,v3_9(1)], outputs=[A_7, ], index=57, debug_index=116, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +float(32) A_7(32) = load(memory=[A], indices=[v3_10,j], outputs=[v3_15, ], data=[0], shape=[m], index=58, debug_index=118, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_11(32) = const(outputs=[v3_12, ], data=[1], index=59, debug_index=120, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_12(32) = sub(inputs=[n,v3_11(1)], outputs=[R_4, ], index=60, debug_index=122, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_13(32) = const(outputs=[v3_14, ], data=[1], index=61, debug_index=124, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_14(32) = sub(inputs=[n,v3_13(1)], outputs=[R_4, ], index=62, debug_index=126, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +float(32) R_4(32) = load(memory=[R(0.f)], indices=[v3_14,v3_12], outputs=[v3_15, ], data=[0], index=63, debug_index=128, debug_name=R, created_in=Tracing initial graph, created_in_func=None, ) +float(32) v3_15(32) = div(inputs=[A_7,R_4], outputs=[v3_18, ], shape=[m], index=64, debug_index=130, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_16(32) = const(outputs=[v3_17, ], data=[1], index=65, debug_index=132, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +int(32) v3_17(32) = sub(inputs=[n,v3_16(1)], outputs=[v3_18, ], index=66, debug_index=134, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +store(memory=[Q(0.f)], inputs=[v3_15], indices=[v3_17,j], shape=[m], index=67, debug_index=136, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) +region_end(index=68, debug_index=138, debug_name=blur, created_in=Tracing initial graph, 
created_in_func=None, ) + diff --git a/TensorFrost/IR/include/Common.h b/TensorFrost/IR/include/Common.h new file mode 100644 index 00000000..5587025d --- /dev/null +++ b/TensorFrost/IR/include/Common.h @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace TensorFrost { +extern "C" { + enum TFType { + Float, + Uint, + Int, + Bool, + Tuple, + None, + }; + + struct TFDataFormat { + TFType type; + size_t size; + + bool operator==(const TFDataFormat& other) const; + bool operator!=(const TFDataFormat& other) const; + int GetHash() const; + bool operator<(const TFDataFormat& other) const; + bool operator>(const TFDataFormat& other) const; + }; + +#define TFTypeNone TFDataFormat{TFType::None, 0} +#define TFTypeTuple TFDataFormat{TFType::Tuple, 0} +#define TFTypeBool32 TFDataFormat{TFType::Bool, 32} +#define TFTypeFloat32 TFDataFormat{TFType::Float, 32} +#define TFTypeInt32 TFDataFormat{TFType::Int, 32} +#define TFTypeUint32 TFDataFormat{TFType::Uint, 32} +} + +// Utility class to automatically resize and set elements in a vector +template +class auto_vector : public std::vector { +public: + void set_element(size_t index, const T& value) { + if (index >= this->size()) { + this->resize(index + 1); + } + (*this)[index] = value; + } +}; + +template> +struct VecHash { + size_t operator()(const std::vector& v) const noexcept { + size_t h = 0; + H hasher; + for (const auto& x : v) + h ^= hasher(x) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; + } +}; + +enum class ArgType { + Input, + Index, + Memory, + Shape, //must be last + Count, +}; + +inline std::string ToString(ArgType type) { + switch (type) { + case ArgType::Input: return "Input"; + case ArgType::Index: return "Index"; + case ArgType::Shape: return "Shape"; + case ArgType::Memory: return "Memory"; + default: return "Unknown"; + } +} + +inline std::string ToString(const TFDataFormat& format) { + switch (format.type) { + case 
TFType::Float: return "Float" + std::to_string(format.size); + case TFType::Uint: return "Uint" + std::to_string(format.size); + case TFType::Int: return "Int" + std::to_string(format.size); + case TFType::Bool: return "Bool" + std::to_string(format.size); + case TFType::Tuple: return "Tuple"; + case TFType::None: return "None"; + default: return "Unknown"; + } +} + +using uint = unsigned int; + +struct Op; +struct Arguments; +struct OpBlock; +class OpBlockIterator; +struct ArgumentManager; +struct ShapeArgs; +struct Argument; + +using Attribute = std::variant; +using AttributeMap = std::unordered_map; + +} diff --git a/TensorFrost/IR/include/ExecutionContext.h b/TensorFrost/IR/include/ExecutionContext.h new file mode 100644 index 00000000..7d8a5292 --- /dev/null +++ b/TensorFrost/IR/include/ExecutionContext.h @@ -0,0 +1,24 @@ +#pragma once + +#include "Common.h" + +namespace TensorFrost { + +struct ExecutionContext { + std::unique_ptr base_block; + OpBlock* current_block; + std::vector stack; + + ExecutionContext(); + + void BeginBlock(Op* op); + void EndBlock(); + + Op &AddOp(std::unique_ptr op); +}; + +void StartExecutionContext(); +ExecutionContext* GetContext(); +void EndExecutionContext(); + +} diff --git a/TensorFrost/IR/include/IR.h b/TensorFrost/IR/include/IR.h deleted file mode 100644 index 9f80eb11..00000000 --- a/TensorFrost/IR/include/IR.h +++ /dev/null @@ -1,114 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include - -namespace TensorFrost { - -using uint = unsigned int; - -struct Type { - std::string dtype; - size_t size = 0; // size in bytes -}; - -struct Value; -struct Op; -struct Arguments; - -using Attribute = std::variant; -using AttributeMap = std::unordered_map; - -struct Value { - int id; - Type type; - Op* producer = nullptr; -}; - -struct Block { - std::list> ops; - Op* append(std::unique_ptr op) { - ops.emplace_back(std::move(op)); - return ops.back().get(); - } -}; - -struct ExecutionContext { - std::unique_ptr 
base_block; - Block* current_block = base_block.get(); - std::vector stack; - - void BeginBlock(Op* op) { - stack.push_back(current_block); - current_block = new Block(); - } - - void EndBlock() { - if (!stack.empty()) { - current_block = stack.back(); - stack.pop_back(); - } else { - current_block = nullptr; - } - } - - void AddOp(std::unique_ptr op) { - if (!current_block) { - throw std::runtime_error("No current block to add operation to"); - } - current_block->append(std::move(op)); - } -}; - -enum class ArgType { - Input, - Output, - Memory, - Shape, - Count -}; - -struct Argument { - ArgType type; - Value* from_value = nullptr; - Op* target_op = nullptr; - int index = 0; -}; - -struct Op { - static ExecutionContext* current_context; - std::string opcode; - std::vector arguments; - std::vector outputs; - std::vector> blocks; - AttributeMap attributes; - - Op(std::string op_name) : opcode(std::move(op_name)) { - if (!current_context) { - throw std::runtime_error("No current execution context set for operation creation"); - } - current_context->AddOp(std::make_unique(*this)); - } - - std::string operator std::string() const { - return opcode; - } - - Op& binary(Op& other, const std::string& op_name) { - Op* new_op = new Op(op_name); - Argument arg1{ArgType::Input, &outputs[0], this, 0}; - Argument arg2{ArgType::Input, &other.outputs[0], &other, 0}; - new_op->arguments.push_back(arg1); - new_op->arguments.push_back(arg2); - new_op->outputs.push_back(Value{0, Type{"float", 4}, new_op}); - return *new_op; - } -}; - - - - -} // namespace ir - diff --git a/TensorFrost/IR/include/Operation.h b/TensorFrost/IR/include/Operation.h new file mode 100644 index 00000000..7d259908 --- /dev/null +++ b/TensorFrost/IR/include/Operation.h @@ -0,0 +1,23 @@ +#pragma once + +#include "Common.h" + +namespace TensorFrost { + +struct Op { + std::string opcode; + std::unique_ptr args; + AttributeMap attributes; + TFDataFormat type; + std::vector> blocks; + + Op(std::string op_name); + 
Op(int value); + Op(uint value); + Op(float value); + Op(bool value); +}; + + +} // namespace ir + diff --git a/TensorFrost/IR/include/OperationArguments.h b/TensorFrost/IR/include/OperationArguments.h new file mode 100644 index 00000000..bcc42f00 --- /dev/null +++ b/TensorFrost/IR/include/OperationArguments.h @@ -0,0 +1,42 @@ +#pragma once +#include "Common.h" + +namespace TensorFrost { + +struct Argument { + ArgType type; + Op* from = nullptr; + Op* to = nullptr; + int index = 0; +}; + +struct Arguments { + Op* parent_op = nullptr; + auto_vector> inputs; + std::set> used_at; + + void AddInput(ArgType type, Op* from, int index = 0); + bool CheckValidity(bool throw_error = false) const; +}; + +struct ShapeArgs : Arguments { + std::vector TryGetShape(int default_value = 256) const; + float GetSizeEstimate(); + void ExpandDimensionsTo(int new_dim); + + bool CompareShape(const ShapeArgs& other, bool throw_error = false) const { + //TODO: Implement shape comparison logic + } +}; + +struct ArgumentManager { + Op* parent_op = nullptr; + std::array, (int)ArgType::Count> type_args; + + ArgumentManager(Op* parent); + void AddArgument(Op* from, ArgType type, int index = 0); + void SetAsOutput(Argument *arg); + void SetArguments(ArgType type, std::vector args); +}; + +} \ No newline at end of file diff --git a/TensorFrost/IR/include/OperationBlocks.h b/TensorFrost/IR/include/OperationBlocks.h new file mode 100644 index 00000000..7ad8705f --- /dev/null +++ b/TensorFrost/IR/include/OperationBlocks.h @@ -0,0 +1,36 @@ +#pragma once +#include "Operation.h" + +namespace TensorFrost { + +struct OpBlock { + std::list> ops; + Op* append(std::unique_ptr op); +}; + +class OpBlockIterator { +public: + using OpIter = std::list>::iterator; + using OpRevIter = std::list>::reverse_iterator; + + struct Frame { + OpBlock* block; + OpIter it; + OpIter end; + }; + + OpBlockIterator(OpBlock* root); + + Op* next(); // Move to next Op in depth-first order + Op* prev(); // Move to previous Op in 
depth-first order + bool down(); // Enter the first sub-block of current Op (if any) + bool up(); // Exit to parent block + + Op* current() const; + +private: + std::vector stack; + Op* current_op; +}; + +} \ No newline at end of file diff --git a/TensorFrost/IR/include/OperationRegistry.h b/TensorFrost/IR/include/OperationRegistry.h new file mode 100644 index 00000000..6b72260e --- /dev/null +++ b/TensorFrost/IR/include/OperationRegistry.h @@ -0,0 +1,20 @@ +#pragma once +#include "Common.h" + +namespace TensorFrost { + +using OverloadsMap = std::unordered_map, TFDataFormat, VecHash>; + +struct OpSpec { + std::string name; + OverloadsMap overloads; + + OpSpec(std::string op_name, OverloadsMap overloads_list); + + TFDataFormat GetOutputType(const std::vector& args) const; +}; + +void RegisterOperation(const OpSpec& spec); +OpSpec* GetOpSpec(const std::string& name); + +} \ No newline at end of file diff --git a/TensorFrost/IR/include/Operations.h b/TensorFrost/IR/include/Operations.h deleted file mode 100644 index bab643d3..00000000 --- a/TensorFrost/IR/include/Operations.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include - -namespace TensorFrost { - -struct OpSpec { - std::string name; - int arity = -1; // -1 variadic -}; - -inline std::unordered_map& registry() { - static std::unordered_map r; - return r; -} - -inline void reg(OpSpec s) { registry()[s.name] = std::move(s); } - -} // namespace ir \ No newline at end of file diff --git a/TensorFrost/IR/include/Overloads.h b/TensorFrost/IR/include/Overloads.h index 5d6fb342..97262d4d 100644 --- a/TensorFrost/IR/include/Overloads.h +++ b/TensorFrost/IR/include/Overloads.h @@ -1,64 +1,236 @@ #pragma once -#include "IR.h" +#include "Operation.h" namespace TensorFrost { -using ArgValue = std::variant; -// Arithmetic operations -Value& operator+(ArgValue& a, ArgValue& b); -Value& operator-(ArgValue& a, ArgValue& b); -Value& operator*(ArgValue& a, ArgValue& b); -Value& 
operator/(ArgValue& a, ArgValue& b); -Value& operator%(ArgValue& a, ArgValue& b); +Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args, std::vector shape); -// Bitwise operations -Value& operator&(ArgValue& a, ArgValue& b); -Value& operator|(ArgValue& a, ArgValue& b); -Value& operator^(ArgValue& a, ArgValue& b); -Value& operator<<(ArgValue& a, ArgValue& b); -Value& operator>>(ArgValue& a, ArgValue& b); -Value& operator~(ArgValue& a); +template +Op& func_op(std::string op, const Args&... args) { + std::vector mem; + std::vector ids; + std::vector args_vec = {&args...}; + std::vector shape; + return make_op(op, mem, ids, args_vec, shape); +} -// Comparison operations -Value& operator==(ArgValue& a, ArgValue& b); -Value& operator!=(ArgValue& a, ArgValue& b); -Value& operator<(ArgValue& a, ArgValue& b); -Value& operator<=(ArgValue& a, ArgValue& b); -Value& operator>(ArgValue& a, ArgValue& b); -Value& operator>=(ArgValue& a, ArgValue& b); +Op& constant(int value); +Op& constant(uint value); +Op& constant(float value); +Op& constant(bool value); + +template +concept Num = std::is_arithmetic_v>; + +template +inline Op& as_op(T v) +{ + using D = std::remove_cvref_t; + using Target = + std::conditional_t, bool, + std::conditional_t, float, + std::conditional_t, unsigned int, + int>>>; + return constant(static_cast(v)); +} -// Logical operations -Value& operator&&(ArgValue& a, ArgValue& b); -Value& operator||(ArgValue& a, ArgValue& b); -Value& operator!(ArgValue& a); +#define UNARY_OPERATOR(op, name) \ +template \ +Op& operator op(const T& a) { \ + return func_op(name, as_op(a)); \ +} -// Increment and decrement operations -Value& operator++(ArgValue& a); -Value& operator--(ArgValue& a); +#define BINARY_OPERATOR(op, name) \ +template \ +Op& operator op(const T& a, const U& b) { \ + return func_op(name, as_op(a), as_op(b)); \ +} -// Mathematical functions -Value& abs(ArgValue& a); -Value& sqrt(ArgValue& a); -Value& pow(ArgValue& base, ArgValue& 
exponent); -Value& sin(ArgValue& a); -Value& cos(ArgValue& a); -Value& tan(ArgValue& a); -Value& log(ArgValue& a); -Value& exp(ArgValue& a); -Value& min(ArgValue& a, ArgValue& b); -Value& max(ArgValue& a, ArgValue& b); -Value& clamp(ArgValue& value, ArgValue& min_val, ArgValue& max_val); -Value& round(ArgValue& a); -Value& floor(ArgValue& a); -Value& ceil(ArgValue& a); -Value& select(ArgValue& condition, ArgValue& true_value, ArgValue& false_value); +#define UNARY_FUNCTION(name, opname) \ +template \ +Op& name(const T& a) { \ + return func_op(opname, as_op(a)); \ +} +#define BINARY_FUNCTION(name, opname) \ +template \ +Op& name(const T& a, const U& b) { \ + return func_op(opname, as_op(a), as_op(b)); \ +} +#define TERNARY_FUNCTION(name, opname) \ +template \ +Op& name(const T& cond, const U& x, const V& y) { \ + return func_op(opname, as_op(cond), as_op(x), as_op(y)); \ +} +UNARY_OPERATOR(+, "pos") +UNARY_OPERATOR(-, "neg") +UNARY_OPERATOR(~, "not") +UNARY_OPERATOR(!, "lnot") + +BINARY_OPERATOR(+, "add") +BINARY_OPERATOR(-, "sub") +BINARY_OPERATOR(*, "mul") +BINARY_OPERATOR(/, "div") +BINARY_OPERATOR(%, "mod") +BINARY_OPERATOR(&, "and") +BINARY_OPERATOR(|, "or") +BINARY_OPERATOR(^, "xor") +BINARY_OPERATOR(<<, "lshift") +BINARY_OPERATOR(>>, "rshift") +BINARY_OPERATOR(==, "eq") +BINARY_OPERATOR(!=, "neq") +BINARY_OPERATOR(<, "lt") +BINARY_OPERATOR(<=, "lte") +BINARY_OPERATOR(>, "gt") +BINARY_OPERATOR(>=, "gte") +BINARY_OPERATOR(&&, "land") +BINARY_OPERATOR(||, "lor") + +UNARY_FUNCTION(copy, "copy") +UNARY_FUNCTION(sin, "sin") +UNARY_FUNCTION(cos, "cos") +UNARY_FUNCTION(tan, "tan") +UNARY_FUNCTION(asin, "asin") +UNARY_FUNCTION(acos, "acos") +UNARY_FUNCTION(atan, "atan") +UNARY_FUNCTION(sinh, "sinh") +UNARY_FUNCTION(cosh, "cosh") +UNARY_FUNCTION(tanh, "tanh") +UNARY_FUNCTION(asinh, "asinh") +UNARY_FUNCTION(acosh, "acosh") +UNARY_FUNCTION(atanh, "atanh") +UNARY_FUNCTION(exp, "exp") +UNARY_FUNCTION(log, "log") +UNARY_FUNCTION(log2, "log2") +UNARY_FUNCTION(exp2, "exp2") 
+UNARY_FUNCTION(sqrt, "sqrt") +UNARY_FUNCTION(sqr, "sqr") +UNARY_FUNCTION(rsqrt, "rsqrt") +UNARY_FUNCTION(rcp, "rcp") +UNARY_FUNCTION(abs, "abs") +UNARY_FUNCTION(sign, "sign") +UNARY_FUNCTION(floor, "floor") +UNARY_FUNCTION(ceil, "ceil") +UNARY_FUNCTION(round, "round") +UNARY_FUNCTION(trunc, "trunc") +UNARY_FUNCTION(frac, "frac") +UNARY_FUNCTION(pcg, "pcg") +UNARY_FUNCTION(pcgf, "pcgf") +UNARY_FUNCTION(reversebits, "reversebits") +UNARY_FUNCTION(tofloat, "tofloat") +UNARY_FUNCTION(toint, "toint") +UNARY_FUNCTION(touint, "touint") +UNARY_FUNCTION(tobool, "tobool") +UNARY_FUNCTION(asfloat, "asfloat") +UNARY_FUNCTION(asint, "asint") +UNARY_FUNCTION(asuint, "asuint") +UNARY_FUNCTION(clamp, "clamp") + +BINARY_FUNCTION(pow, "pow") +BINARY_FUNCTION(min, "min") +BINARY_FUNCTION(max, "max") +BINARY_FUNCTION(mod, "mod") +BINARY_FUNCTION(modf, "modf") +BINARY_FUNCTION(atan2, "atan2") +BINARY_FUNCTION(grad, "backwards_grad") + +TERNARY_FUNCTION(lerp, "lerp") +TERNARY_FUNCTION(smoothstep, "smoothstep") +TERNARY_FUNCTION(select, "ternary") +TERNARY_FUNCTION(fma, "fma") +// Arithmetic operations +Op& operator+(const Op& a, const Op& b); +Op& operator-(const Op& a, const Op& b); +Op& operator*(const Op& a, const Op& b); +Op& operator/(const Op& a, const Op& b); +Op& operator%(const Op& a, const Op& b); +// Bitwise operations +Op& operator&(const Op& a, const Op& b); +Op& operator|(const Op& a, const Op& b); +Op& operator^(const Op& a, const Op& b); +Op& operator<<(const Op& a, const Op& b); +Op& operator>>(const Op& a, const Op& b); +Op& operator~(const Op& a); +// Comparison operations +Op& operator==(const Op& a, const Op& b); +Op& operator!=(const Op& a, const Op& b); +Op& operator<(const Op& a, const Op& b); +Op& operator<=(const Op& a, const Op& b); +Op& operator>(const Op& a, const Op& b); +Op& operator>=(const Op& a, const Op& b); + +// Logical operations +Op& operator&&(const Op& a, const Op& b); +Op& operator||(const Op& a, const Op& b); +Op& operator!(const Op& a); +// 
Increment and decrement operations +Op& operator++(const Op& a); +Op& operator--(const Op& a); + +Op& operator+=(const Op& a, const Op& b); +Op& operator-=(const Op& a, const Op& b); + +Op& copy(const Op& a); +Op& sin(const Op& a); +Op& cos(const Op& a); +Op& tan(const Op& a); +Op& asin(const Op& a); +Op& acos(const Op& a); +Op& atan(const Op& a); +Op& sinh(const Op& a); +Op& cosh(const Op& a); +Op& tanh(const Op& a); +Op& asinh(const Op& a); +Op& acosh(const Op& a); +Op& atanh(const Op& a); +Op& exp(const Op& a); +Op& log(const Op& a); +Op& log2(const Op& a); +Op& exp2(const Op& a); +Op& sqrt(const Op& a); +Op& sqr(const Op& a); +Op& rsqrt(const Op& a); +Op& rcp(const Op& a); +Op& abs(const Op& a); +Op& sign(const Op& a); +Op& floor(const Op& a); +Op& ceil(const Op& a); +Op& round(const Op& a); +Op& trunc(const Op& a); +Op& frac(const Op& a); + +Op& pcg(const Op& a); +Op& pcgf(const Op& a); + +Op& reversebits(const Op& a); + +Op& tofloat(const Op& a); +Op& toint(const Op& a); +Op& touint(const Op& a); +Op& tobool(const Op& a); + +Op& asfloat(const Op& a); +Op& asint(const Op& a); +Op& asuint(const Op& a); + +Op& clamp(const Op& x, const Op& min, const Op& max); +Op& pow(const Op& x, const Op& y); +Op& min(const Op& x, const Op& y); +Op& max(const Op& x, const Op& y); +Op& mod(const Op& x, const Op& y); +Op& modf(const Op& x, const Op& y); +Op& atan2(const Op& x, const Op& y); +Op& grad(const Op& x, const Op& wrt); +Op& lerp(const Op& x, const Op& y, const Op& a); +Op& smoothstep(const Op& a, const Op& b, const Op& x); +Op& select(const Op& cond, const Op& x, const Op& y); +Op& fma(const Op& x, const Op& y, const Op& z); } diff --git a/TensorFrost/IR/src/Common.cpp b/TensorFrost/IR/src/Common.cpp new file mode 100644 index 00000000..b0fc8c64 --- /dev/null +++ b/TensorFrost/IR/src/Common.cpp @@ -0,0 +1,24 @@ +#include "../include/Common.h" + +using namespace TensorFrost; + +bool TFDataFormat::operator==(const TFDataFormat &other) const { + return type == other.type 
&& size == other.size; +} + +bool TFDataFormat::operator!=(const TFDataFormat &other) const { + return !(*this == other); +} + +int TFDataFormat::GetHash() const { + return (int)type << 16 | (int)size; +} + +bool TFDataFormat::operator<(const TFDataFormat &other) const { + return GetHash() < other.GetHash(); +} + +bool TFDataFormat::operator>(const TFDataFormat &other) const { + return GetHash() > other.GetHash(); +} + diff --git a/TensorFrost/IR/src/ExecutionContext.cpp b/TensorFrost/IR/src/ExecutionContext.cpp new file mode 100644 index 00000000..bc941210 --- /dev/null +++ b/TensorFrost/IR/src/ExecutionContext.cpp @@ -0,0 +1,47 @@ +#include "../include/ExecutionContext.h" +#include "../include/Operation.h" + +using namespace TensorFrost; + +ExecutionContext::ExecutionContext(): base_block(std::make_unique()), current_block(base_block.get()) {} + +void ExecutionContext::BeginBlock(Op *op) { + stack.push_back(current_block); + current_block = new OpBlock(); +} + +void ExecutionContext::EndBlock() { + if (!stack.empty()) { + current_block = stack.back(); + stack.pop_back(); + } else { + throw std::runtime_error("No block to end"); + } +} + +Op &ExecutionContext::AddOp(std::unique_ptr op) { + current_block->append(std::move(op)); + return *current_block->ops.back(); +} + +ExecutionContext* current_context = nullptr; + +void StartExecutionContext() { + if (current_context) { + throw std::runtime_error("Execution context already started"); + } + current_context = new ExecutionContext(); +} + +ExecutionContext* GetContext() { + return current_context; +} + +void EndExecutionContext() { + if (!current_context) { + throw std::runtime_error("No execution context to end"); + } + delete current_context; + current_context = nullptr; +} + diff --git a/TensorFrost/IR/src/IR.cpp b/TensorFrost/IR/src/IR.cpp deleted file mode 100644 index e69de29b..00000000 diff --git a/TensorFrost/IR/src/Operation.cpp b/TensorFrost/IR/src/Operation.cpp new file mode 100644 index 
00000000..8ef73e6f --- /dev/null +++ b/TensorFrost/IR/src/Operation.cpp @@ -0,0 +1,30 @@ +#include "../include/Operation.h" +#include "../include/OperationArguments.h" +#include "../include/Overloads.h" + +using namespace TensorFrost; + +Op::Op(std::string op_name): opcode(std::move(op_name)) { + args = std::make_unique(this); + type = TFTypeNone; +} + +Op::Op(int value) : Op("const") { + attributes["value"] = value; + type = TFTypeInt32; +} + +Op::Op(uint value) : Op("const") { + attributes["value"] = value; + type = TFTypeUint32; +} + +Op::Op(float value) : Op("const") { + attributes["value"] = value; + type = TFTypeFloat32; +} + +Op::Op(bool value) : Op(std::string("const")) { + attributes["value"] = value; + type = TFTypeBool32; +} diff --git a/TensorFrost/IR/src/OperationArguments.cpp b/TensorFrost/IR/src/OperationArguments.cpp new file mode 100644 index 00000000..58e75438 --- /dev/null +++ b/TensorFrost/IR/src/OperationArguments.cpp @@ -0,0 +1,44 @@ +#include "../include/OperationArguments.h" +#include "../include/Operation.h" + +using namespace TensorFrost; + +void Arguments::AddInput(ArgType type, Op *from, int index) { + inputs.set_element(index, std::make_unique(Argument{type, from, parent_op, index})); + from->args->SetAsOutput(inputs[index].get()); +} + +bool Arguments::CheckValidity(bool throw_error) const { + for (const auto& input : inputs) { + if (!input || !input->from) { + if (throw_error) { + throw std::runtime_error("Invalid argument"); + } + return false; + } + } + return true; +} + +ArgumentManager::ArgumentManager(Op *parent): parent_op(parent) { + for (int i = 0; i < (int)ArgType::Shape; ++i) { + type_args[i] = std::make_unique(); + type_args[i]->parent_op = parent; + } + type_args[(int)ArgType::Shape] = std::make_unique(); + type_args[(int)ArgType::Shape]->parent_op = parent; +} + +void ArgumentManager::AddArgument(Op *from, ArgType type, int index) { + type_args[(int)type]->AddInput(type, from, index); +} + +void 
ArgumentManager::SetAsOutput(Argument *arg) { + type_args[(int)arg->type]->used_at.insert({arg->index, arg}); +} + +void ArgumentManager::SetArguments(ArgType type, std::vector args) { + for (size_t i = 0; i < args.size(); ++i) { + AddArgument(args[i], type, (int)i); + } +} diff --git a/TensorFrost/IR/src/OperationBlocks.cpp b/TensorFrost/IR/src/OperationBlocks.cpp new file mode 100644 index 00000000..27ef0e3c --- /dev/null +++ b/TensorFrost/IR/src/OperationBlocks.cpp @@ -0,0 +1,83 @@ +#include "../include/OperationBlocks.h" + +using namespace TensorFrost; + +Op* OpBlock::append(std::unique_ptr op) { + ops.emplace_back(std::move(op)); + return ops.back().get(); +} + +OpBlockIterator::OpBlockIterator(OpBlock* root) : current_op(nullptr) { + if (root && !root->ops.empty()) { + stack.push_back({root, root->ops.begin(), root->ops.end()}); + current_op = stack.back().it->get(); + } +} + +Op* OpBlockIterator::current() const { + return current_op; +} + +Op* OpBlockIterator::next() { + if (stack.empty()) return nullptr; + // If current op has sub-blocks, go down + if (!current_op->blocks.empty() && current_op->blocks[0] && !current_op->blocks[0]->ops.empty()) { + OpBlock* sub = current_op->blocks[0].get(); + stack.push_back({sub, sub->ops.begin(), sub->ops.end()}); + current_op = stack.back().it->get(); + return current_op; + } + // Otherwise, go to next op in current block or up + while (!stack.empty()) { + auto& frame = stack.back(); + ++frame.it; + if (frame.it != frame.end) { + current_op = frame.it->get(); + return current_op; + } else { + stack.pop_back(); + } + } + current_op = nullptr; + return nullptr; +} + +Op* OpBlockIterator::prev() { + if (stack.empty()) return nullptr; + auto& frame = stack.back(); + if (frame.it == frame.block->ops.begin()) { + stack.pop_back(); + if (!stack.empty()) { + current_op = stack.back().it->get(); + return current_op; + } + current_op = nullptr; + return nullptr; + } + --frame.it; + // Go to the deepest last op in sub-blocks if 
any + Op* op = frame.it->get(); + while (!op->blocks.empty() && op->blocks[0] && !op->blocks[0]->ops.empty()) { + OpBlock* sub = op->blocks[0].get(); + stack.push_back({sub, --sub->ops.end(), sub->ops.end()}); + op = stack.back().it->get(); + } + current_op = op; + return current_op; +} + +bool OpBlockIterator::down() { + if (!current_op || current_op->blocks.empty() || !current_op->blocks[0] || current_op->blocks[0]->ops.empty()) + return false; + OpBlock* sub = current_op->blocks[0].get(); + stack.push_back({sub, sub->ops.begin(), sub->ops.end()}); + current_op = stack.back().it->get(); + return true; +} + +bool OpBlockIterator::up() { + if (stack.size() <= 1) return false; + stack.pop_back(); + current_op = stack.back().it->get(); + return true; +} \ No newline at end of file diff --git a/TensorFrost/IR/src/OperationRegistry.cpp b/TensorFrost/IR/src/OperationRegistry.cpp new file mode 100644 index 00000000..50b9df73 --- /dev/null +++ b/TensorFrost/IR/src/OperationRegistry.cpp @@ -0,0 +1,90 @@ +#include "../include/Operation.h" +#include "../include/OperationRegistry.h" + +using namespace TensorFrost; +using namespace std; + +OpSpec::OpSpec(std::string op_name, OverloadsMap overloads_list) { + name = std::move(op_name); + overloads = std::move(overloads_list); +} + +TFDataFormat OpSpec::GetOutputType(const std::vector &args) const { + auto it = overloads.find(args); + if (it == overloads.end()) { + throw std::runtime_error("No overload found for operation: " + name + " with args: " + to_string(args.size())); + } + return it->second; +} + +static const std::unordered_map tok = { + {"f", TFDataFormat::TFTypeFloat32}, + {"i", TFDataFormat::TFTypeInt32}, + {"u", TFDataFormat::TFTypeUint32}, + {"tuple", TFDataFormat::TFTypeTuple}, + {"b", TFDataFormat::TFTypeBool32}, + {"void", TFDataFormat::TFTypeNone}, +}; + +static std::string trim(std::string_view s) { + size_t a = 0, b = s.size(); + while (a < b && std::isspace(static_cast(s[a]))) ++a; + while (b > a && 
std::isspace(static_cast(s[b - 1]))) --b; + return std::string{s.substr(a, b - a)}; +} + +OverloadsMap ovr(const std::string& input) { + OverloadsMap out; + std::stringstream ss(input); + std::string stmt; + while (std::getline(ss, stmt, ';')) { + stmt = trim(stmt); + if (stmt.empty()) continue; + auto l = stmt.find('('), r = stmt.find(')'); + if (l == std::string::npos || r == std::string::npos || r < l) throw std::runtime_error("Overload syntax error: " + stmt); + auto tgt = trim(stmt.substr(0, l)); + auto args = stmt.substr(l + 1, r - l - 1); + std::vector key; + std::stringstream as(args); + std::string tokarg; + while (std::getline(as, tokarg, ',')) { + tokarg = trim(tokarg); + key.push_back(tok.at(tokarg)); + } + out.emplace(std::move(key), tok.at(tgt)); + } + return out; +} + +vector default_operations = { + OpSpec("add", ovr("f(f,f); u(u,u); i(i,i)")), + OpSpec("sub", ovr("f(f,f); u(u,u); i(i,i)")), + OpSpec("mul", ovr("f(f,f); u(u,u); i(i,i)")), + OpSpec("div", ovr("f(f,f); u(u,u); i(i,i)")), + + OpSpec("parallel", ovr("tuple()")), +}; + +std::unordered_map CreateOperationRegistry() { + std::unordered_map registry; + for (const auto& op : default_operations) { + registry[op.name] = op; + } + return registry; +} + +std::unordered_map operation_registry = CreateOperationRegistry(); + +void TensorFrost::RegisterOperation(const OpSpec &spec) { + if (operation_registry.contains(spec.name)) { + throw std::runtime_error("Operation already registered: " + spec.name); + } + operation_registry[spec.name] = spec; +} + +OpSpec* TensorFrost::GetOpSpec(const std::string &name) { + if (!operation_registry.contains(name)) { + throw std::runtime_error("Operation not found: " + name); + } + return &operation_registry[name]; +} diff --git a/TensorFrost/IR/src/Overloads.cpp b/TensorFrost/IR/src/Overloads.cpp index e69de29b..ce2fc259 100644 --- a/TensorFrost/IR/src/Overloads.cpp +++ b/TensorFrost/IR/src/Overloads.cpp @@ -0,0 +1,149 @@ +#include "../include/Overloads.h" 
+#include "../include/ExecutionContext.h" +#include "../include/OperationRegistry.h" +#include "../include/OperationArguments.h" + +using namespace TensorFrost; +using namespace std; + +// General function to create an Op instance in the current execution context +Op& make_op(string op, vector mem, vector ids, vector args, vector shape) { + OpSpec* spec = GetOpSpec(op); + vector arg_types; + for (const auto& arg : args) { + arg_types.push_back(arg->type); + } + TFDataFormat output_type = spec->GetOutputType(arg_types); + Op* op_instance = new Op(op); + op_instance->type = output_type; + op_instance->args->SetArguments(ArgType::Memory, mem); + op_instance->args->SetArguments(ArgType::Index, ids); + op_instance->args->SetArguments(ArgType::Input, args); + op_instance->args->SetArguments(ArgType::Shape, shape); + return GetContext()->AddOp(std::unique_ptr(op_instance)); +} + +Op& constant(int value) { + Op& const_op = func_op("const"); + const_op.attributes["value"] = value; + const_op.type = TFTypeInt32; + return const_op; +} + +Op& constant(uint value) { + Op& const_op = func_op("const"); + const_op.attributes["value"] = value; + const_op.type = TFTypeUint32; + return const_op; +} + +Op& constant(float value) { + Op& const_op = func_op("const"); + const_op.attributes["value"] = value; + const_op.type = TFTypeFloat32; + return const_op; +} + +Op& constant(bool value) { + Op& const_op = func_op("const"); + const_op.attributes["value"] = value; + const_op.type = TFTypeBool32; + return const_op; +} + +// Arithmetic operations +Op& operator+(const Op& a, const Op& b) { return func_op("add", a, b); } +Op& operator-(const Op& a, const Op& b) { return func_op("sub", a, b); } +Op& operator*(const Op& a, const Op& b) { return func_op("mul", a, b); } +Op& operator/(const Op& a, const Op& b) { return func_op("div", a, b); } +Op& operator%(const Op& a, const Op& b) { return func_op("mod", a, b); } + +// Bitwise operations +Op& operator&(const Op& a, const Op& b) { return 
func_op("and", a, b); } +Op& operator|(const Op& a, const Op& b) { return func_op("or", a, b); } +Op& operator^(const Op& a, const Op& b) { return func_op("xor", a, b); } +Op& operator<<(const Op& a, const Op& b) { return func_op("lshift", a, b); } +Op& operator>>(const Op& a, const Op& b) { return func_op("rshift", a, b); } +Op& operator~(const Op& a) { return func_op("not", a); } + +// Comparison operations +Op& operator==(const Op& a, const Op& b) { return func_op("eq", a, b); } +Op& operator!=(const Op& a, const Op& b) { return func_op("neq", a, b); } +Op& operator<(const Op& a, const Op& b) { return func_op("lt", a, b); } +Op& operator<=(const Op& a, const Op& b) { return func_op("lte", a, b); } +Op& operator>(const Op& a, const Op& b) { return func_op("gt", a, b); } +Op& operator>=(const Op& a, const Op& b) { return func_op("gte", a, b); } + +// Logical operations +Op& operator&&(const Op& a, const Op& b) { return func_op("land", a, b); } +Op& operator||(const Op& a, const Op& b) { return func_op("lor", a, b); } +Op& operator!(const Op& a) { return func_op("lnot", a); } + +// Assignment operations +// Op& operator+=(const Op& a, const Op& b) { return func_op("add_assign", a, b); } +// Op& operator-=(const Op& a, const Op& b) { return func_op("sub_assign", a, b); } +// Op& operator*=(const Op& a, const Op& b) { return func_op("mul_assign", a, b); } +// Op& operator/=(const Op& a, const Op& b) { return func_op("div_assign", a, b); } +// Op& operator%=(const Op& a, const Op& b) { return func_op("mod_assign", a, b); } +// Op& operator&=(const Op& a, const Op& b) { return func_op("and_assign", a, b); } +// Op& operator|=(const Op& a, const Op& b) { return func_op("or_assign", a, b); } +// Op& operator^=(const Op& a, const Op& b) { return func_op("xor_assign", a, b); } +// Op& operator<<=(const Op& a, const Op& b) { return func_op("lshift_assign", a, b); } +// Op& operator>>=(const Op& a, const Op& b) { return func_op("rshift_assign", a, b); } +// Op& 
operator++(const Op& a) { return a += 1; } +// Op& operator--(const Op& a) { return a -= 1; } + +Op& copy(const Op& a) { return func_op("copy", a); } +Op& sin(const Op& a) { return func_op("sin", a); } +Op& cos(const Op& a) { return func_op("cos", a); } +Op& tan(const Op& a) { return func_op("tan", a); } +Op& asin(const Op& a) { return func_op("asin", a); } +Op& acos(const Op& a) { return func_op("acos", a); } +Op& atan(const Op& a) { return func_op("atan", a); } +Op& sinh(const Op& a) { return func_op("sinh", a); } +Op& cosh(const Op& a) { return func_op("cosh", a); } +Op& tanh(const Op& a) { return func_op("tanh", a); } +Op& asinh(const Op& a) { return func_op("asinh", a); } +Op& acosh(const Op& a) { return func_op("acosh", a); } +Op& atanh(const Op& a) { return func_op("atanh", a); } +Op& exp(const Op& a) { return func_op("exp", a); } +Op& log(const Op& a) { return func_op("log", a); } +Op& log2(const Op& a) { return func_op("log2", a); } +Op& exp2(const Op& a) { return func_op("exp2", a); } +Op& sqrt(const Op& a) { return func_op("sqrt", a); } +Op& sqr(const Op& a) { return func_op("sqr", a); } +Op& rsqrt(const Op& a) { return func_op("rsqrt", a); } +Op& rcp(const Op& a) { return func_op("rcp", a); } +Op& abs(const Op& a) { return func_op("abs", a); } +Op& sign(const Op& a) { return func_op("sign", a); } +Op& floor(const Op& a) { return func_op("floor", a); } +Op& ceil(const Op& a) { return func_op("ceil", a); } +Op& round(const Op& a) { return func_op("round", a); } +Op& trunc(const Op& a) { return func_op("trunc", a); } +Op& frac(const Op& a) { return func_op("frac", a); } +Op& pcg(const Op& a) { return func_op("pcg", a); } +Op& pcgf(const Op& a) { return func_op("pcgf", a); } +Op& reversebits(const Op& a) { return func_op("reversebits", a); } + +Op& clamp(const Op& x, const Op& min, const Op& max) { return func_op("clamp", x, min, max); } +Op& pow(const Op& x, const Op& y) { return func_op("pow", x, y); } +Op& min(const Op& x, const Op& y) { return 
func_op("min", x, y); } +Op& max(const Op& x, const Op& y) { return func_op("max", x, y); } +Op& mod(const Op& x, const Op& y) { return func_op("mod", x, y); } +Op& modf(const Op& x, const Op& y) { return func_op("modf", x, y); } +Op& atan2(const Op& x, const Op& y) { return func_op("atan2", x, y); } +Op& grad(const Op& x, const Op& wrt) { return func_op("backwards_grad", x, wrt); } +Op& lerp(const Op& x, const Op& y, const Op& a) { return func_op("lerp", x, y, a); } +Op& smoothstep(const Op& a, const Op& b, const Op& x) { return func_op("smoothstep", a, b, x); } +Op& select(Op& cond, const Op& x, const Op& y) { return func_op("ternary", cond, x, y); } +Op& fma(const Op& x, const Op& y, const Op& z) { return func_op("fma", x, y, z); } + +// Type conversion operations +Op& tofloat(const Op& a) { return func_op("tofloat", a); } +Op& toint(const Op& a) { return func_op("toint", a); } +Op& touint(const Op& a) { return func_op("touint", a); } +Op& tobool(const Op& a) { return func_op("tobool", a); } + +Op& asfloat(const Op& a) { return func_op("asfloat", a); } +Op& asint(const Op& a) { return func_op("asint", a); } +Op& asuint(const Op& a) { return func_op("asuint", a); } \ No newline at end of file From 3713ac325e3a16c2b0c468f01440a17c74bfa45a Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Sun, 1 Jun 2025 05:17:39 +0200 Subject: [PATCH 03/44] NUKE EVERYTHING --- TensorFrost/Backend/Backend.cpp | 187 --- TensorFrost/Backend/Backend.h | 49 - TensorFrost/Backend/Backends/CPU/CPU.h | 5 - .../Backend/Backends/CPU/KernelCompiler.cpp | 222 --- .../Backend/Backends/CPU/KernelCompiler.h | 37 - .../Backend/Backends/CPU/KernelManager.h | 41 - TensorFrost/Backend/Backends/CPU/Memory.h | 60 - .../Backend/Backends/OpenGL/KernelManager.h | 183 --- TensorFrost/Backend/Backends/OpenGL/Memory.h | 102 -- .../Backend/Backends/OpenGL/OpenGL.cpp | 384 ----- TensorFrost/Backend/Backends/OpenGL/OpenGL.h | 54 - TensorFrost/Backend/CodeGen/Generators.cpp | 556 ------- 
TensorFrost/Backend/CodeGen/Generators.h | 205 --- TensorFrost/Backend/CodeGen/Langs/CPP.cpp | 781 ---------- TensorFrost/Backend/CodeGen/Langs/GLSL.cpp | 188 --- TensorFrost/Backend/CodeGen/Langs/HLSL.cpp | 135 -- TensorFrost/Backend/CodeGen/Langs/Listing.cpp | 155 -- TensorFrost/Backend/KernelManager.cpp | 38 - TensorFrost/Backend/KernelManager.h | 34 - TensorFrost/Backend/RenderDoc.cpp | 56 - TensorFrost/Backend/RenderDoc.h | 8 - TensorFrost/Backend/TensorMemory.cpp | 236 --- TensorFrost/Backend/TensorMemory.h | 146 -- TensorFrost/CMakeLists.txt | 85 +- TensorFrost/Compiler/Graph/Arguments.cpp | 19 - TensorFrost/Compiler/Graph/Arguments.h | 221 --- TensorFrost/Compiler/Graph/IR.cpp | 226 --- TensorFrost/Compiler/Graph/IR.h | 413 ----- TensorFrost/Compiler/Graph/Node.cpp | 365 ----- TensorFrost/Compiler/Graph/Node.h | 289 ---- TensorFrost/Compiler/Graph/Scope.cpp | 325 ---- TensorFrost/Compiler/Graph/Scope.h | 173 --- TensorFrost/Compiler/Implementations.cpp | 783 ---------- TensorFrost/Compiler/Implementations.h | 112 -- TensorFrost/Compiler/KernelGen.cpp | 74 - TensorFrost/Compiler/KernelGen.h | 80 - TensorFrost/Compiler/Operations.cpp | 226 --- TensorFrost/Compiler/Operations.h | 284 ---- TensorFrost/Compiler/Steps/Algorithms.cpp | 64 - TensorFrost/Compiler/Steps/Autodiff.cpp | 154 -- TensorFrost/Compiler/Steps/GraphOps.cpp | 1343 ----------------- TensorFrost/Compiler/Steps/Optimization.cpp | 760 ---------- .../Frontend/Python/Definitions/PyModule.cpp | 62 - .../Frontend/Python/Definitions/PyTensor.cpp | 199 --- .../Python/Definitions/TensorFunctions.cpp | 353 ----- .../Python/Definitions/TensorMemory.cpp | 72 - .../Python/Definitions/TensorProgram.cpp | 199 --- .../Python/Definitions/TensorScope.cpp | 118 -- .../Python/Definitions/WindowUtils.cpp | 148 -- TensorFrost/Frontend/Python/PyModule.h | 430 ------ TensorFrost/Frontend/Python/PyTensor.cpp | 226 --- TensorFrost/Frontend/Python/PyTensor.h | 104 -- .../Frontend/Python/PyTensorMemory.cpp | 106 -- 
TensorFrost/Frontend/Python/PyTensorMemory.h | 60 - TensorFrost/Frontend/Python/PybindModule.cpp | 128 -- TensorFrost/IR/include/Overloads.h | 236 --- TensorFrost/IR/src/Overloads.cpp | 149 -- TensorFrost/PybindModule.cpp | 134 ++ TensorFrost/Tensor/Tensor.cpp | 357 ----- TensorFrost/Tensor/Tensor.h | 1228 --------------- TensorFrost/Tensor/TensorProgram.cpp | 89 -- TensorFrost/Tensor/TensorProgram.h | 46 - TensorFrost/TensorFrost.h | 10 - TensorFrost/Utility/Utility.cpp | 13 - TensorFrost/Utility/Utility.h | 132 -- .../{IR/include => include/Compiler}/Common.h | 13 +- .../Compiler}/ExecutionContext.h | 0 .../include => include/Compiler}/Operation.h | 5 + .../Compiler}/OperationArguments.h | 0 .../Compiler}/OperationBlocks.h | 0 .../Compiler}/OperationRegistry.h | 0 TensorFrost/include/Compiler/Overloads.h | 149 ++ .../{IR/src => src/Compiler}/Common.cpp | 7 +- .../src => src/Compiler}/ExecutionContext.cpp | 10 +- .../{IR/src => src/Compiler}/Operation.cpp | 8 +- .../Compiler}/OperationArguments.cpp | 7 +- .../src => src/Compiler}/OperationBlocks.cpp | 6 +- .../Compiler}/OperationRegistry.cpp | 30 +- TensorFrost/src/Compiler/Overloads.cpp | 51 + TensorFrost/src/Definitions/PyModule.cpp | 62 + TensorFrost/src/Definitions/PyTensor.cpp | 199 +++ .../src/Definitions/TensorFunctions.cpp | 353 +++++ TensorFrost/src/Definitions/TensorMemory.cpp | 72 + TensorFrost/src/Definitions/TensorProgram.cpp | 199 +++ TensorFrost/src/Definitions/TensorScope.cpp | 118 ++ TensorFrost/src/Definitions/WindowUtils.cpp | 148 ++ 86 files changed, 1585 insertions(+), 14309 deletions(-) delete mode 100644 TensorFrost/Backend/Backend.cpp delete mode 100644 TensorFrost/Backend/Backend.h delete mode 100644 TensorFrost/Backend/Backends/CPU/CPU.h delete mode 100644 TensorFrost/Backend/Backends/CPU/KernelCompiler.cpp delete mode 100644 TensorFrost/Backend/Backends/CPU/KernelCompiler.h delete mode 100644 TensorFrost/Backend/Backends/CPU/KernelManager.h delete mode 100644 
TensorFrost/Backend/Backends/CPU/Memory.h delete mode 100644 TensorFrost/Backend/Backends/OpenGL/KernelManager.h delete mode 100644 TensorFrost/Backend/Backends/OpenGL/Memory.h delete mode 100644 TensorFrost/Backend/Backends/OpenGL/OpenGL.cpp delete mode 100644 TensorFrost/Backend/Backends/OpenGL/OpenGL.h delete mode 100644 TensorFrost/Backend/CodeGen/Generators.cpp delete mode 100644 TensorFrost/Backend/CodeGen/Generators.h delete mode 100644 TensorFrost/Backend/CodeGen/Langs/CPP.cpp delete mode 100644 TensorFrost/Backend/CodeGen/Langs/GLSL.cpp delete mode 100644 TensorFrost/Backend/CodeGen/Langs/HLSL.cpp delete mode 100644 TensorFrost/Backend/CodeGen/Langs/Listing.cpp delete mode 100644 TensorFrost/Backend/KernelManager.cpp delete mode 100644 TensorFrost/Backend/KernelManager.h delete mode 100644 TensorFrost/Backend/RenderDoc.cpp delete mode 100644 TensorFrost/Backend/RenderDoc.h delete mode 100644 TensorFrost/Backend/TensorMemory.cpp delete mode 100644 TensorFrost/Backend/TensorMemory.h delete mode 100644 TensorFrost/Compiler/Graph/Arguments.cpp delete mode 100644 TensorFrost/Compiler/Graph/Arguments.h delete mode 100644 TensorFrost/Compiler/Graph/IR.cpp delete mode 100644 TensorFrost/Compiler/Graph/IR.h delete mode 100644 TensorFrost/Compiler/Graph/Node.cpp delete mode 100644 TensorFrost/Compiler/Graph/Node.h delete mode 100644 TensorFrost/Compiler/Graph/Scope.cpp delete mode 100644 TensorFrost/Compiler/Graph/Scope.h delete mode 100644 TensorFrost/Compiler/Implementations.cpp delete mode 100644 TensorFrost/Compiler/Implementations.h delete mode 100644 TensorFrost/Compiler/KernelGen.cpp delete mode 100644 TensorFrost/Compiler/KernelGen.h delete mode 100644 TensorFrost/Compiler/Operations.cpp delete mode 100644 TensorFrost/Compiler/Operations.h delete mode 100644 TensorFrost/Compiler/Steps/Algorithms.cpp delete mode 100644 TensorFrost/Compiler/Steps/Autodiff.cpp delete mode 100644 TensorFrost/Compiler/Steps/GraphOps.cpp delete mode 100644 
TensorFrost/Compiler/Steps/Optimization.cpp delete mode 100644 TensorFrost/Frontend/Python/Definitions/PyModule.cpp delete mode 100644 TensorFrost/Frontend/Python/Definitions/PyTensor.cpp delete mode 100644 TensorFrost/Frontend/Python/Definitions/TensorFunctions.cpp delete mode 100644 TensorFrost/Frontend/Python/Definitions/TensorMemory.cpp delete mode 100644 TensorFrost/Frontend/Python/Definitions/TensorProgram.cpp delete mode 100644 TensorFrost/Frontend/Python/Definitions/TensorScope.cpp delete mode 100644 TensorFrost/Frontend/Python/Definitions/WindowUtils.cpp delete mode 100644 TensorFrost/Frontend/Python/PyModule.h delete mode 100644 TensorFrost/Frontend/Python/PyTensor.cpp delete mode 100644 TensorFrost/Frontend/Python/PyTensor.h delete mode 100644 TensorFrost/Frontend/Python/PyTensorMemory.cpp delete mode 100644 TensorFrost/Frontend/Python/PyTensorMemory.h delete mode 100644 TensorFrost/Frontend/Python/PybindModule.cpp delete mode 100644 TensorFrost/IR/include/Overloads.h delete mode 100644 TensorFrost/IR/src/Overloads.cpp create mode 100644 TensorFrost/PybindModule.cpp delete mode 100644 TensorFrost/Tensor/Tensor.cpp delete mode 100644 TensorFrost/Tensor/Tensor.h delete mode 100644 TensorFrost/Tensor/TensorProgram.cpp delete mode 100644 TensorFrost/Tensor/TensorProgram.h delete mode 100644 TensorFrost/TensorFrost.h delete mode 100644 TensorFrost/Utility/Utility.cpp delete mode 100644 TensorFrost/Utility/Utility.h rename TensorFrost/{IR/include => include/Compiler}/Common.h (89%) rename TensorFrost/{IR/include => include/Compiler}/ExecutionContext.h (100%) rename TensorFrost/{IR/include => include/Compiler}/Operation.h (72%) rename TensorFrost/{IR/include => include/Compiler}/OperationArguments.h (100%) rename TensorFrost/{IR/include => include/Compiler}/OperationBlocks.h (100%) rename TensorFrost/{IR/include => include/Compiler}/OperationRegistry.h (100%) create mode 100644 TensorFrost/include/Compiler/Overloads.h rename TensorFrost/{IR/src => 
src/Compiler}/Common.cpp (89%) rename TensorFrost/{IR/src => src/Compiler}/ExecutionContext.cpp (88%) rename TensorFrost/{IR/src => src/Compiler}/Operation.cpp (79%) rename TensorFrost/{IR/src => src/Compiler}/OperationArguments.cpp (92%) rename TensorFrost/{IR/src => src/Compiler}/OperationBlocks.cpp (97%) rename TensorFrost/{IR/src => src/Compiler}/OperationRegistry.cpp (78%) create mode 100644 TensorFrost/src/Compiler/Overloads.cpp create mode 100644 TensorFrost/src/Definitions/PyModule.cpp create mode 100644 TensorFrost/src/Definitions/PyTensor.cpp create mode 100644 TensorFrost/src/Definitions/TensorFunctions.cpp create mode 100644 TensorFrost/src/Definitions/TensorMemory.cpp create mode 100644 TensorFrost/src/Definitions/TensorProgram.cpp create mode 100644 TensorFrost/src/Definitions/TensorScope.cpp create mode 100644 TensorFrost/src/Definitions/WindowUtils.cpp diff --git a/TensorFrost/Backend/Backend.cpp b/TensorFrost/Backend/Backend.cpp deleted file mode 100644 index 8163a39e..00000000 --- a/TensorFrost/Backend/Backend.cpp +++ /dev/null @@ -1,187 +0,0 @@ -#include "Backend.h" - -namespace TensorFrost { - -BackendType current_backend = BackendType::NotInitialized; -CodeGenLang current_kernel_lang = CodeGenLang::CPP; -CodeGenLang current_main_lang = CodeGenLang::CPP; -bool strip_debug_names = false; - -void InitializeBackend(BackendType backendType, const string& compilerOptions, CodeGenLang kernelType) { - if (current_backend != BackendType::NotInitialized) { - cout << "Warning: Backend already initialized, stopping current backend\n" << endl; - - switch (current_backend) { - case BackendType::CPU: - break; - case BackendType::Vulkan: - break; - case BackendType::OpenGL: - StopOpenGL(); - break; - default: - throw std::runtime_error("Backend not implemented"); - } - } - - if (!compilerOptions.empty()) { - kernelCompileOptions = compilerOptions; - } else if(backendType != BackendType::CPU) { - kernelCompileOptions = ""; //no need for cpu optimizations on 
other backends - } else { -#ifdef _WIN32 - kernelCompileOptions = "/O2 /fp:fast /openmp"; -#else - kernelCompileOptions = "-O3 -ffast-math -fopenmp"; -#endif - } - -#ifdef _DEBUG -#ifdef _WIN32 - kernelCompileOptions = "/Zi"; -#else - kernelCompileOptions = "-g"; -#endif -#endif - - current_backend = backendType; - - switch (backendType) { - case BackendType::CPU: - case BackendType::CodeGen: - current_kernel_lang = CodeGenLang::CPP; - global_memory_manager = new CpuMemoryManager(); - global_kernel_manager = new CpuKernelManager(); - break; - case BackendType::Vulkan: - throw std::runtime_error("Vulkan backend not implemented yet"); - current_kernel_lang = CodeGenLang::GLSL; - break; - case BackendType::OpenGL: - StartOpenGL(); - current_kernel_lang = CodeGenLang::GLSL; - global_memory_manager = new OpenGLMemoryManager(); - global_kernel_manager = new OpenGLKernelManager(); - break; - default: - throw std::runtime_error("Backend not implemented"); - } - - if (kernelType != CodeGenLang::None) { - current_kernel_lang = kernelType; - } -} - -void CompileKernels(Program* program) { - auto start_time = chrono::high_resolution_clock::now(); - for(auto& kernel : program->kernels_) { - switch (current_backend) { - case BackendType::CPU: - //already in the host program - break; - case BackendType::Vulkan: - throw std::runtime_error("Vulkan backend not implemented yet"); - case BackendType::OpenGL: - ((OpenGLKernelManager*)global_kernel_manager)->CompileKernel(&kernel); - break; - default: - throw std::runtime_error("Backend not implemented"); - } - } - auto end_time = chrono::high_resolution_clock::now(); - float milliseconds = chrono::duration(end_time - start_time).count(); - program->shader_compile_time = milliseconds; -} - -TFTensor Allocator(const char* name, const size_t* a, size_t dim, TFDataFormat format, void* data) { - try { - vector shape(a, a + dim); - return *global_memory_manager->AllocateTensor(shape, format, name); - } catch (const std::exception& e) { - 
size_t size = 1; - for (size_t i = 0; i < dim; i++) { - size *= a[i]; - } - throw std::runtime_error("Error allocating tensor " + string(name) + ": " + e.what() + ", requested size: " + to_string(size)); - } -} - -void Deallocator(TFTensor a, void* data) { - global_memory_manager->DeallocateTensor(a); -} - -uint Readback(TFTensor a, size_t index, void* data) { - return global_memory_manager->ReadbackValue(&a, index); -} - -void Writeback(TFTensor a, size_t index, uint32_t value, void* data) { - global_memory_manager->WritebackValue(&a, index, value); -} - -void Dispatch(TFDispatchInfo info, void* data) { - global_kernel_manager->DispatchKernel(info); -} - -void Region(const char* name, bool begin, void* data) { - if (current_backend == BackendType::OpenGL) { - if (begin) { - StartDebugRegion(name); - } else { - EndDebugRegion(); - } - } -} - -//#define PROFILE_EXECUTION - -vector ExecuteProgram( - Program* program, vector inputs) { - - if (current_backend == BackendType::CodeGen) { - throw std::runtime_error("Cannot execute program with code generation backend"); - } - - int memory_input_count = (int)program->ir_->input_memory_map.size(); - - if (memory_input_count != inputs.size()) { - throw std::runtime_error( - "Invalid number of inputs for TensorProgram. 
Expected " + - to_string(memory_input_count) + ", got " + to_string(inputs.size())); - } - - vector input_tensors; - for (int i = 0; i < memory_input_count; i++) { - input_tensors.push_back(*inputs[i]); - } - - unordered_map output_memory_map = program->ir_->output_memory_map; - int output_count = (int)output_memory_map.size(); - - TFTensor* in = input_tensors.data(); - TFTensor* out = new TFTensor[output_count]; - -#ifdef PROFILE_EXECUTION - auto start = chrono::high_resolution_clock::now(); -#endif - try { - program->execute_callback(in, out, {Allocator, Deallocator, Readback, Writeback, Dispatch, Region, nullptr}); - } catch (const std::exception& e) { - throw std::runtime_error("Error executing program " + program->program_name + ": " + e.what()); - } - -#ifdef PROFILE_EXECUTION - Finish(); - auto end = chrono::high_resolution_clock::now(); - float milliseconds = chrono::duration(end - start).count(); - program->last_execution_time = milliseconds; -#endif - - vector outputs = vector(output_count); - for (int i = 0; i < output_count; i++) { - outputs[i] = &out[i]; - } - - return outputs; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backend.h b/TensorFrost/Backend/Backend.h deleted file mode 100644 index dd4151f9..00000000 --- a/TensorFrost/Backend/Backend.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "Backends/CPU/CPU.h" -#include "Backends/OpenGL/OpenGL.h" -#include "CodeGen/Generators.h" -#include "KernelManager.h" -#include "TensorMemory.h" -#include "RenderDoc.h" - -namespace TensorFrost { - -using namespace std; - -enum class BackendType { - CPU, - Vulkan, - OpenGL, - CodeGen, - NotInitialized -}; - -enum class CodeGenLang { - CPP, - HLSL, - GLSL, - None, -}; - -extern BackendType current_backend; -extern CodeGenLang current_kernel_lang; -extern CodeGenLang current_main_lang; -extern bool strip_debug_names; - -vector 
ExecuteProgram( - Program* program, vector inputs); - -void InitializeBackend(BackendType backendType, const string& compilerPath, CodeGenLang kernelType); - -void CompileKernels(Program* program); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/CPU/CPU.h b/TensorFrost/Backend/Backends/CPU/CPU.h deleted file mode 100644 index c166ba59..00000000 --- a/TensorFrost/Backend/Backends/CPU/CPU.h +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "KernelCompiler.h" -#include "Memory.h" -#include "KernelManager.h" \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/CPU/KernelCompiler.cpp b/TensorFrost/Backend/Backends/CPU/KernelCompiler.cpp deleted file mode 100644 index 6a220b0f..00000000 --- a/TensorFrost/Backend/Backends/CPU/KernelCompiler.cpp +++ /dev/null @@ -1,222 +0,0 @@ -#include "KernelCompiler.h" - -#include - -namespace TensorFrost { - -std::string kernelCompileOptions; - -//TODO: Add support for MacOS -#ifndef __APPLE__ -bool RunCompiler(char* tempPath, char* dllName, const char* sourcePath) { - std::basic_stringstream ss; - std::string output; - -#if defined(_WIN32) - ss << "powershell -command \"$VisualStudioPath = & \\\"${Env:ProgramFiles(x86)}\\Microsoft Visual Studio\\Installer\\vswhere.exe\\\" -latest -products * -property installationPath; & cmd.exe /C \\\"\"\\\"\\\"$VisualStudioPath\\VC\\Auxiliary\\Build\\vcvarsall.bat\\\"\\\" x64 && cl " - << kernelCompileOptions << " /LD " << tempPath - << sourcePath << " /Fe:" << dllName - << "\"\"\\\"\""; -#else - ss << "g++ " << kernelCompileOptions << " -shared -fPIC " << tempPath - << sourcePath << " -o " << dllName; -#endif - - std::basic_string command = ss.str(); - - // Run the compiler -#if defined(_WIN32) - SECURITY_ATTRIBUTES sa; - sa.nLength = sizeof(SECURITY_ATTRIBUTES); - sa.bInheritHandle = TRUE; - sa.lpSecurityDescriptor = NULL; - - HANDLE hReadPipe, hWritePipe; - if (!CreatePipe(&hReadPipe, &hWritePipe, &sa, 0)) { - throw 
std::runtime_error("Failed to create pipe"); - } - - STARTUPINFO si; - PROCESS_INFORMATION pi; - ZeroMemory(&si, sizeof(si)); - si.cb = sizeof(si); - si.hStdError = hWritePipe; - si.hStdOutput = hWritePipe; - si.dwFlags |= STARTF_USESTDHANDLES; - - if (!CreateProcess(nullptr, command.data(), nullptr, nullptr, TRUE, 0, nullptr, nullptr, &si, &pi)) { - throw std::runtime_error(std::string("Steps error: cannot create compiler process. Command line: ") + command.data() + "\n"); - } - - CloseHandle(hWritePipe); - - char buffer[4096]; - DWORD bytesRead; - while (ReadFile(hReadPipe, buffer, sizeof(buffer), &bytesRead, NULL) && bytesRead != 0) { - output.append(buffer, bytesRead); - } - - WaitForSingleObject(pi.hProcess, INFINITE); - - DWORD exit_code; - GetExitCodeProcess(pi.hProcess, &exit_code); - if (exit_code != 0) { - throw std::runtime_error( - "Steps error: compiler exited with non-zero exit code (Error " - "code: " + std::to_string(exit_code) + ")\nCompiler output:\n" + output); - } - - CloseHandle(pi.hProcess); - CloseHandle(pi.hThread); - CloseHandle(hReadPipe); -#else - FILE* pipe = popen(command.c_str(), "r"); - if (!pipe) { - throw std::runtime_error("popen() failed!"); - } - - char buffer[128]; - while (fgets(buffer, sizeof(buffer), pipe) != nullptr) { - output += buffer; - } - - int status = pclose(pipe); - if (status != 0) { - throw std::runtime_error( - "Steps error: compiler exited with non-zero exit code (Error " - "code: " + std::to_string(status) + ")\nCompiler output:\n" + output); - } -#endif - - return true; -} - -void CompileKernelLibrary(const string& sourceCode, char* tempPath, - char* dllName, size_t program_id) { - // Append a file name to the tempPath - std::string source_name = "generated_lib_" + std::to_string(program_id) + ".cpp"; - std::basic_stringstream ss; - ss << tempPath << source_name; - std::basic_string full_file_path = ss.str(); - - const std::string& file_path(full_file_path); - - cout << "Source path: " << file_path << endl; - 
- // Write the generated source code to a file - std::ofstream out_file(file_path); - if (!out_file) { - throw std::runtime_error( - "Steps error: cannot open file for writing generated source code"); - } - out_file << sourceCode; - out_file.close(); - - RunCompiler(tempPath, dllName, source_name.c_str()); -} -#endif - -void CompileAndLoadKernelModule(Program* program, size_t program_id) { -#ifndef __APPLE__ -#if defined(_WIN32) - char temp_path[MAX_PATH]; - DWORD path_length = GetTempPath(MAX_PATH, temp_path); - - if (path_length == 0) { - throw std::runtime_error("Steps error: cannot get temp path"); - } - - // Create a temporary library name - char temp_file_name[MAX_PATH]; - if (!GetTempFileName(temp_path, TEXT("lib"), 0, temp_file_name)) { - throw std::runtime_error("Steps error: cannot create temp file"); - } -#else - char temp_path[] = "/tmp/"; - char filename_template[] = "/tmp/tensorfrost_XXXXXX"; - char* temp_file_name = mktemp(filename_template); - if (!temp_file_name) { - throw std::runtime_error("Steps error: cannot create temp file"); - } -#endif - - cout << "Temp file: " << temp_file_name << endl; - - // Compile the library - CompileKernelLibrary(program->generated_code_, temp_path, temp_file_name, program_id); - - // Load the library - #if defined(_WIN32) - HMODULE lib_handle = LoadLibrary(temp_file_name); - if (!lib_handle) { - throw std::runtime_error("Steps error: cannot load generated library"); - } - #else - void* lib_handle = dlopen(temp_file_name, RTLD_LAZY); - if (!lib_handle) { - throw std::runtime_error("Steps error: cannot load generated library"); - } - #endif - - // Create lambda function to free the library - program->unload_callback = [lib_handle]() { - #if defined(_WIN32) - if (!FreeLibrary(lib_handle)) { - std::cerr << "Cannot free library: " << GetLastError() << '\n'; - } - #else - if (dlclose(lib_handle)) { - std::cerr << "Cannot free library: " << dlerror() << '\n'; - } - #endif - }; - - // Load the main function - #if 
defined(_WIN32) - auto main_callback = reinterpret_cast( - GetProcAddress(lib_handle, "main")); - #else - auto main_callback = reinterpret_cast( - dlsym(lib_handle, "main")); - #endif - - if (!main_callback) { - throw std::runtime_error("Steps error: cannot load main function"); - } - - // Set the execute callback - program->execute_callback = *main_callback; - - // load cpu kernel functions - if (current_backend == BackendType::CPU) - { - for (auto& kernel : program->kernels_) { - #if defined(_WIN32) - auto kernel_callback = reinterpret_cast( - GetProcAddress(lib_handle, kernel.kernel_name_.c_str())); - #else - auto kernel_callback = reinterpret_cast( - dlsym(lib_handle, kernel.kernel_name_.c_str())); - #endif - - if (!kernel_callback) { - throw std::runtime_error("Steps error: cannot load kernel function"); - } - - ((CpuKernelManager*)global_kernel_manager) - ->AddKernelFunction(&kernel, kernel_callback); - -#ifndef NDEBUG - cout << "Loaded kernel: " << kernel.kernel_name_ << endl; -#endif - } - } - - cout << "Successfully compiled and loaded kernel library." 
<< endl; - -#else - throw std::runtime_error("Steps error: cannot compile and load kernel module on macOS"); -#endif -} - -} // namespace TensorFrost diff --git a/TensorFrost/Backend/Backends/CPU/KernelCompiler.h b/TensorFrost/Backend/Backends/CPU/KernelCompiler.h deleted file mode 100644 index 3f397f99..00000000 --- a/TensorFrost/Backend/Backends/CPU/KernelCompiler.h +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#ifndef NOMINMAX -#define NOMINMAX -#endif - -#if defined(_WIN32) -#include -#else -#include -#include -#include -#endif - -#include "Backend/Backends/CPU/Memory.h" -#include "Backend/CodeGen/Generators.h" -#include "Backend/KernelManager.h" -#include "Backend/TensorMemory.h" -#include "Compiler/KernelGen.h" - -namespace TensorFrost { - -using namespace std; - -extern std::string kernelCompileOptions; - -void CompileAndLoadKernelModule(Program* program, size_t program_id); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/CPU/KernelManager.h b/TensorFrost/Backend/Backends/CPU/KernelManager.h deleted file mode 100644 index eb03b1d4..00000000 --- a/TensorFrost/Backend/Backends/CPU/KernelManager.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../KernelManager.h" - -namespace TensorFrost { - -class CpuKernelManager : public KernelManager { - unordered_map kernel_functions; - public: - - void AddKernelFunction(Kernel* kernel, cpu_dispatch_func* func) { - kernel_functions[kernel->kernel_id_] = func; - } - - cpu_dispatch_func* GetKernel(size_t id) { - return kernel_functions[id]; - } - - void DispatchKernel(TFDispatchInfo info) override - { - cpu_dispatch_func* func = kernel_functions[info.kernel_id]; - //get memory pointers - uint32_t** memory = new uint32_t*[info.read_write_count]; - for (size_t i = 0; i < info.read_write_count; i++) { 
- memory[i] = ((TFCPUBuffer*)info.read_write_tensors[i].buffer)->GetNative(); - } - func(info.variables, memory, (uint)info.work_group_count); - delete[] memory; - } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/CPU/Memory.h b/TensorFrost/Backend/Backends/CPU/Memory.h deleted file mode 100644 index cbc2375a..00000000 --- a/TensorFrost/Backend/Backends/CPU/Memory.h +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../TensorMemory.h" - -namespace TensorFrost { - -using namespace std; - -class TFCPUBuffer: public TFBufferTemplate { -public: - uint32_t* data; - - TFCPUBuffer(size_t size): TFBufferTemplate(size) { - data = new uint32_t[size]; - } - - void UpdateName(const char* new_name) override { - if(new_name != nullptr) { - name = new_name; - } - } - - void SetDataAtOffset(size_t offset, const vector& data) override { - memcpy(this->data + offset, data.data(), data.size() * sizeof(uint32_t)); - } - - void GetDataAtOffset(size_t offset, size_t size, uint32_t* data) override { - memcpy(data, this->data + offset, size * sizeof(uint32_t)); - } - - uint32_t* GetNative() const { - return data; - } - - ~TFCPUBuffer() { - delete[] data; - } -}; - -class CpuMemoryManager : public TensorMemoryManager { - public: - TFBuffer* CreateBuffer(size_t size) override { - return new TFCPUBuffer(size); - } - - void DeleteBuffer(TFBuffer* buffer) override { - delete (TFCPUBuffer*)buffer; - } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/OpenGL/KernelManager.h b/TensorFrost/Backend/Backends/OpenGL/KernelManager.h deleted file mode 100644 index a9d5929d..00000000 --- a/TensorFrost/Backend/Backends/OpenGL/KernelManager.h +++ /dev/null @@ -1,183 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../KernelManager.h" - 
-namespace TensorFrost { - -class OpenGLKernelManager : public KernelManager { - unordered_map kernel_map; - const int WORK_GROUP_SIZE = 256; - GLuint ubo; - public: - OpenGLKernelManager() { - glGenBuffers(1, &ubo); - //allocate sizeof(uint32_t) * 32 bytes - glBindBuffer(GL_UNIFORM_BUFFER, ubo); - glBufferData(GL_UNIFORM_BUFFER, sizeof(uint32_t) * 32, nullptr, GL_DYNAMIC_DRAW); - glBindBuffer(GL_UNIFORM_BUFFER, 0); - } - - void UpdateUBO(const uint32_t* data, size_t size) { - glBindBuffer(GL_UNIFORM_BUFFER, ubo); - glBufferSubData(GL_UNIFORM_BUFFER, 0, sizeof(uint32_t) * size, data); - glBindBuffer(GL_UNIFORM_BUFFER, 0); - } - - GLuint createComputeShader(const std::string& source) { - GLuint shader = glCreateShader(GL_COMPUTE_SHADER); - const char* src = source.c_str(); - glShaderSource(shader, 1, &src, nullptr); - glCompileShader(shader); - - // Check for compilation errors - GLint success; - glGetShaderiv(shader, GL_COMPILE_STATUS, &success); - if (!success) { - GLchar infoLog[512]; - glGetShaderInfoLog(shader, 512, nullptr, infoLog); - throw std::runtime_error("TensorFrost: Error compiling shader: " + source + "\n" + std::string(infoLog)); - } - - return shader; - } - - GLuint createShaderProgram(const std::string& computeShaderSource) { - GLuint computeShader = createComputeShader(computeShaderSource); - GLuint program = glCreateProgram(); - glAttachShader(program, computeShader); - glLinkProgram(program); - - // Check for linking errors - GLint success; - glGetProgramiv(program, GL_LINK_STATUS, &success); - if (!success) { - GLchar infoLog[512]; - glGetProgramInfoLog(program, 512, nullptr, infoLog); - throw std::runtime_error("Error linking program: " + std::string(infoLog)); - } - - glDeleteShader(computeShader); - return program; - } - - void CompileKernel(Kernel* kernel) - { - #ifndef NDEBUG //print out source if debug is enabled - cout << "Compiling kernel \n" << kernel->full_generated_code_ << endl; - #endif - try { - GLuint program = 
createShaderProgram(kernel->full_generated_code_); - kernel_map[kernel->kernel_id_] = program; - } catch (const std::exception& e) { - string error_message = "Error compiling kernel " + to_string(kernel->kernel_id_) + "\n" + e.what(); - error_message = error_message + "\n" + kernel->full_generated_code_; - throw std::runtime_error(error_message); - } - } - - //Get uniform location - GLint getUniformLocation(GLuint program, const std::string& name) { - GLint location = glGetUniformLocation(program, name.c_str()); - if (location == -1) { - throw std::runtime_error("OpenGL error: uniform " + name + " not found"); - } - return location; - } - - //Get attribute location - GLint getAttribLocation(GLuint program, const std::string& name) { - GLint location = glGetAttribLocation(program, name.c_str()); - if (location == -1) { - throw std::runtime_error("OpenGL error: attribute " + name + " not found"); - } - return location; - } - - void DispatchKernel(TFDispatchInfo info) override - { - GLuint program = kernel_map[info.kernel_id]; - Kernel* kernel = GetKernel(info.kernel_id); - glUseProgram(program); - - #ifndef NDEBUG - // validate the program - glValidateProgram(program); - GLint success; - glGetProgramiv(program, GL_VALIDATE_STATUS, &success); - if (!success) { - GLchar infoLog[512]; - glGetProgramInfoLog(program, 512, nullptr, infoLog); - throw std::runtime_error("OpenGL error: program validation failed: " + std::string(infoLog)); - } - #endif - - // Set uniforms - if (info.read_write_count == 0) throw std::runtime_error("No tensors provided to kernel"); - - //bind all memory buffers - for (size_t i = 0; i < info.read_write_count; i++) { - GLuint buffer = ((TFOpenGLBuffer*)info.read_write_tensors[i].buffer)->GetNative(); - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, (GLuint)i, buffer); - } - - if (info.variable_count > 0) - { - UpdateUBO(info.variables, info.variable_count); - } - - // Bind the UBO - glBindBufferBase(GL_UNIFORM_BUFFER, 0, ubo); - - // Dispatch the 
kernel - glDispatchCompute((GLuint)info.work_group_count, 1, 1); - - // Wait for the kernel to finish - glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT); - - // Unbind the memory buffers - for (size_t i = 0; i < info.read_write_count; i++) { - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, (GLuint)i, 0); - } - - // Unbind the program - glUseProgram(0); - - // Check for errors - GLenum error = glGetError(); - if (error != GL_NO_ERROR) { - throw std::runtime_error("OpenGL error: " + std::to_string(error)); - } - } - - void FreeKernel(int kernel_id) - { - GLuint program = kernel_map[kernel_id]; - glDeleteProgram(program); - kernel_map.erase(kernel_id); - } - - void FreeAllKernels() - { - for (auto& kernel : kernel_map) { - glDeleteProgram(kernel.second); - } - kernel_map.clear(); - } - - ~OpenGLKernelManager() - { - FreeAllKernels(); - glDeleteBuffers(1, &ubo); - } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/OpenGL/Memory.h b/TensorFrost/Backend/Backends/OpenGL/Memory.h deleted file mode 100644 index 07103023..00000000 --- a/TensorFrost/Backend/Backends/OpenGL/Memory.h +++ /dev/null @@ -1,102 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../TensorMemory.h" - -namespace TensorFrost { - -class TFOpenGLBuffer: public TFBufferTemplate { - GLuint buffer; - - const size_t max_cache_size = 16384; - uint32_t* cached_data = nullptr; - - public: - TFOpenGLBuffer(size_t size): TFBufferTemplate(size) { - GLint maxsize; - glGetIntegerv(GL_MAX_SHADER_STORAGE_BLOCK_SIZE, &maxsize); - - if (size * sizeof(uint32_t) > maxsize) { - throw std::runtime_error("SSBO memory size exceeded, max size is " + std::to_string(maxsize)); - } - - glCreateBuffers(1, &buffer); - glNamedBufferStorage(buffer, size * sizeof(uint32_t), nullptr, GL_DYNAMIC_STORAGE_BIT); - } - - void UpdateName(const char* new_name) override { - if(new_name != nullptr && strcmp(new_name, "") != 0) { 
- name = new_name; - glObjectLabel(GL_BUFFER, buffer, -1, name); - } - } - - void UpdateCache(size_t data_offset, size_t data_size, const uint32_t* data) { - return; - if(data_offset == 0 && data_size == used_size && data_size <= max_cache_size) { - if(cached_data == nullptr) { - cached_data = new uint32_t[data_size]; - } - memcpy(cached_data, data, data_size * sizeof(uint32_t)); - up_to_date = true; - } - } - - void SetDataAtOffset(size_t offset, const vector& data) override { - glBindBuffer(GL_SHADER_STORAGE_BUFFER, buffer); - glBufferSubData(GL_SHADER_STORAGE_BUFFER, offset * sizeof(uint32_t), - data.size() * sizeof(uint32_t), data.data()); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - - UpdateCache(offset, data.size(), data.data()); - } - - void GetDataAtOffset(size_t offset, size_t size, uint32_t* data) override { - if(up_to_date) { - memcpy(data, cached_data + offset, size * sizeof(uint32_t)); - return; - } - - glBindBuffer(GL_SHADER_STORAGE_BUFFER, buffer); - glGetBufferSubData(GL_SHADER_STORAGE_BUFFER, offset * sizeof(uint32_t), - size * sizeof(uint32_t), data); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - - UpdateCache(offset, size, data); - } - - GLuint GetNative() const { - return buffer; - } - - ~TFOpenGLBuffer() { - glDeleteBuffers(1, &buffer); - if(cached_data != nullptr) { - delete[] cached_data; - } - } -}; - -class OpenGLMemoryManager : public TensorMemoryManager { - public: - OpenGLMemoryManager() {} - - TFBuffer* CreateBuffer(size_t size) override { - return new TFOpenGLBuffer(size); - } - - void DeleteBuffer(TFBuffer* buffer) override { - delete (TFOpenGLBuffer*)buffer; - } -}; - - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/OpenGL/OpenGL.cpp b/TensorFrost/Backend/Backends/OpenGL/OpenGL.cpp deleted file mode 100644 index 9478b3cd..00000000 --- a/TensorFrost/Backend/Backends/OpenGL/OpenGL.cpp +++ /dev/null @@ -1,384 +0,0 @@ -#include "OpenGL.h" - -namespace TensorFrost { - -string 
vertex_shader = R"( -#version 430 core -out vec2 texCoords; - -void main() { - const vec2 pos[6] = vec2[](vec2(-1.0, -1.0), vec2(1.0, -1.0), vec2(1.0, 1.0), - vec2(-1.0, -1.0), vec2(1.0, 1.0), vec2(-1.0, 1.0)); - gl_Position = vec4(pos[gl_VertexID], 0.0, 1.0); - texCoords = pos[gl_VertexID] * 0.5 + 0.5; -} -)"; - -string fragment_shader = R"( -#version 430 core -in vec2 texCoords; -out vec4 FragColor; - -layout(std430, binding = 0) buffer memory { - uint mem[]; -}; - -uniform int offset; -uniform int width; -uniform int height; - -void main() { - ivec2 pixel = ivec2(texCoords.x * width, texCoords.y * height); - int pixel_offset = pixel.y * width + pixel.x; - int cur_offset = pixel_offset * 3 + offset; - float r = uintBitsToFloat(mem[cur_offset]); - float g = uintBitsToFloat(mem[cur_offset + 1]); - float b = uintBitsToFloat(mem[cur_offset + 2]); - FragColor = vec4(r, g, b, 1.0); -} -)"; - -GLuint CreateShader(const string& source, GLenum type) { - GLuint shader = glCreateShader(type); - const char* src = source.c_str(); - glShaderSource(shader, 1, &src, nullptr); - glCompileShader(shader); - - int success; - glGetShaderiv(shader, GL_COMPILE_STATUS, &success); - if (!success) { - char infoLog[512]; - glGetShaderInfoLog(shader, 512, nullptr, infoLog); - throw std::runtime_error("Failed to compile shader: " + string(infoLog)); - } - - return shader; -} - -GLuint CreateProgram(const string& vertexSource, const string& fragmentSource) { - GLuint vertexShader = CreateShader(vertexSource, GL_VERTEX_SHADER); - GLuint fragmentShader = CreateShader(fragmentSource, GL_FRAGMENT_SHADER); - - GLuint program = glCreateProgram(); - glAttachShader(program, vertexShader); - glAttachShader(program, fragmentShader); - glLinkProgram(program); - - int success; - glGetProgramiv(program, GL_LINK_STATUS, &success); - - return program; -} - -GLuint quad_program = 0; -GLFWwindow* global_window = nullptr; -bool window_open = false; -ImGuiIO* io; - -void GLAPIENTRY DebugCallback(GLenum source, 
GLenum type, GLuint id, - GLenum severity, GLsizei length, - const GLchar* message, const void* userParam) { - // Output or log the debug message - std::cerr << "OpenGL Debug: source=" << source << ", type=" << type - << ", id=" << id << ", severity=" << severity << endl; - std::cerr << "Message: " << message << endl << endl; -} - - -// Window close callback function -void WindowCloseCallback(GLFWwindow* window) { - window_open = false; - - // Instead of closing, hide the window - glfwHideWindow(window); - - // Prevent the window from closing - glfwSetWindowShouldClose(window, GLFW_FALSE); -} - -void WindowSizeCallback(GLFWwindow* window, int width, int height) { - glViewport(0, 0, width, height); -} - -void ImguiNewFrame() { - if (global_window == nullptr) { - throw std::runtime_error("Window: OpenGL not initialized"); - } - - ImGui_ImplOpenGL3_NewFrame(); - ImGui_ImplGlfw_NewFrame(); - ImGui::NewFrame(); -} - -void ImguiRender() { - if (global_window == nullptr) { - throw std::runtime_error("Window: OpenGL not initialized"); - } - - ImGui::Render(); - ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData()); -} - -std::string last_error; - -void error_callback(int error, const char* description) { - last_error = std::string(description); -} - -//#define OPENGL_DEBUG - -void StartOpenGL() { - glfwSetErrorCallback(error_callback); - - if (!glfwInit()) { - throw std::runtime_error("Failed to initialize GLFW: " + last_error); - } - - // Make window invisible - glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE); - global_window = glfwCreateWindow(800, 600, "TensorFrost", nullptr, nullptr); - - if (!global_window) { - int code = glfwGetError(nullptr); - glfwTerminate(); - throw std::runtime_error("Failed to create window (error " + std::to_string(code) + "): " + last_error); - } - - glfwMakeContextCurrent(global_window); - - - int version = gladLoadGL(glfwGetProcAddress); - if (version == 0) { - throw std::runtime_error("Failed to load OpenGL"); - } - - // Successfully loaded 
OpenGL - printf("Loaded OpenGL %d.%d\n", GLAD_VERSION_MAJOR(version), - GLAD_VERSION_MINOR(version)); - - // Print the renderer string, which usually contains the GPU's name - const GLubyte* renderer = glGetString(GL_RENDERER); - const GLubyte* vendor = glGetString(GL_VENDOR); - printf("Renderer: %s\nVendor: %s\n", renderer, vendor); - - // Print the max buffer size - GLint bufferSize; - glGetIntegerv(GL_MAX_SHADER_STORAGE_BLOCK_SIZE, &bufferSize); - printf("Max available buffer size, MB: %d\n", bufferSize / 1024 / 1024); - - // Print max SSBO bindings - GLint max_ssbo_bindings; - glGetIntegerv(GL_MAX_SHADER_STORAGE_BUFFER_BINDINGS, &max_ssbo_bindings); - printf("Maximum SSBO bindings supported: %d\n", max_ssbo_bindings); - - // Enable debug output - #if !defined(NDEBUG) || defined(OPENGL_DEBUG) - if (GLAD_GL_KHR_debug) { - glEnable(GL_DEBUG_OUTPUT); - glEnable(GL_DEBUG_OUTPUT_SYNCHRONOUS); - glDebugMessageCallback(DebugCallback, nullptr); - glDebugMessageControl(GL_DONT_CARE, GL_DONT_CARE, GL_DEBUG_SEVERITY_HIGH, 0, nullptr, - GL_TRUE); - } - #endif - - // Create the shader program - quad_program = CreateProgram(vertex_shader, fragment_shader); - - glfwSetWindowCloseCallback(global_window, WindowCloseCallback); - glfwSetWindowSizeCallback(global_window, WindowSizeCallback); - - IMGUI_CHECKVERSION(); - ImGui::CreateContext(); - io = &ImGui::GetIO(); (void)*io; - io->ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard; - io->ConfigFlags |= ImGuiConfigFlags_NavEnableGamepad; - - // Setup Dear ImGui style - ImGui::StyleColorsDark(); - - ImGui_ImplGlfw_InitForOpenGL(global_window, true); - ImGui_ImplOpenGL3_Init("#version 430"); - - ImguiNewFrame(); -} - -void StopOpenGL() { - if (global_window == nullptr) { - throw std::runtime_error("OpenGL not initialized"); - } - - glfwDestroyWindow(global_window); - glfwTerminate(); - - glDeleteProgram(quad_program); - global_window = nullptr; - - // Cleanup ImGui - ImGui_ImplOpenGL3_Shutdown(); - ImGui_ImplGlfw_Shutdown(); - 
ImGui::DestroyContext(); -} - -void ShowWindow(int width, int height, const char* title) { - window_open = true; - - if(global_window == nullptr) { - throw std::runtime_error("Window: OpenGL not initialized"); - } - - glfwSetWindowSize(global_window, width, height); - glfwSetWindowTitle(global_window, title); - glfwShowWindow(global_window); - - //reset viewport - glViewport(0, 0, width, height); -} - -void HideWindow() { - if(global_window == nullptr) { - throw std::runtime_error("Window: OpenGL not initialized"); - } - - window_open = false; - glfwHideWindow(global_window); -} - -void Finish() { - glFinish(); -} - -void RenderFrame(const TFTensor* tensor) { - if (global_window == nullptr) { - throw std::runtime_error("RenderFrame: OpenGL not initialized"); - } - - // Clear the screen - glClearColor(0.0f, 0.0f, 0.0f, 1.0f); - glClear(GL_COLOR_BUFFER_BIT); - - if(tensor != nullptr) - { - //check if tensor is 2d + 3 channels - if (tensor->dim != 3 || tensor->shape[2] != 3) { - throw std::runtime_error("Window: Render tensor must be of shape (height, width, 3)"); - } - - //check if tensor is float32 (TODO: use int8 instead) - if (tensor->format != TFTypeFloat32) { - throw std::runtime_error("Window: Render tensor must be of type float32"); - } - - GLuint ssbo = ((TFOpenGLBuffer*)tensor->buffer)->GetNative(); - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo); - - glUseProgram(quad_program); - - // Set the uniforms - int offset = 0; - int width = (int)tensor->shape[1]; - int height = (int)tensor->shape[0]; - glUniform1i(glGetUniformLocation(quad_program, "offset"), offset); - glUniform1i(glGetUniformLocation(quad_program, "width"), width); - glUniform1i(glGetUniformLocation(quad_program, "height"), height); - - // Draw the quad - glDrawArrays(GL_TRIANGLES, 0, 6); - - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, 0); - glUseProgram(0); - } - - ImguiRender(); - - // Swap the buffers - glfwSwapBuffers(global_window); - glfwPollEvents(); - - ImguiNewFrame(); -} - -bool 
WindowShouldClose() { return !window_open; } - -pair GetMousePosition() { - double x, y; - glfwGetCursorPos(global_window, &x, &y); - return {x, y}; -} - -pair GetWindowSize() { - int width, height; - glfwGetWindowSize(global_window, &width, &height); - return {width, height}; -} - -bool IsMouseButtonPressed(int button) { - //if pressed in imgui, return false - if (io->WantCaptureMouse) { - return false; - } - return glfwGetMouseButton(global_window, button) == GLFW_PRESS; -} - -bool IsKeyPressed(int key) { - return glfwGetKey(global_window, key) == GLFW_PRESS; -} - -void ImGuiBegin(std::string name) { - ImGui::Begin(name.c_str()); -} - -void ImGuiEnd() { - ImGui::End(); -} - -void ImGuiText(const std::string& text) { - ImGui::Text("%s", text.c_str()); -} - -void ImGuiSlider(std::string text, int* value, int min, int max) { - ImGui::SliderInt(text.c_str(), value, min, max); -} - -void ImGuiSlider(std::string text, float* value, float min, float max) { - ImGui::SliderFloat(text.c_str(), value, min, max, "%.5f"); -} - -bool ImGuiButton(std::string text) { - return ImGui::Button(text.c_str()); -} - -void ImGuiPlotLines(const char *label, const float *values, int values_count, int values_offset, - const char *overlay_text, float scale_min, float scale_max, ImVec2 graph_size, int stride) { - ImGui::PlotLines(label, values, values_count, values_offset, overlay_text, scale_min, scale_max, graph_size, stride); -} - -bool ImGuiCheckbox(std::string text, bool* value) { - return ImGui::Checkbox(text.c_str(), value); -} - -void ImGuiScaleAllSizes(float scale) { - ImGuiStyle& style = ImGui::GetStyle(); - style.ScaleAllSizes(scale); -} - -void ImGuiAddBackgroundText(const std::string &text, const ImVec2 &pos, const ImVec4 &color) { - ImDrawList* draw_list = ImGui::GetBackgroundDrawList(); - draw_list->AddText(pos, ImGui::GetColorU32(color), text.c_str()); -} - -void ImGuiColorPicker3(const std::string &text, float *color) { - ImGui::ColorPicker3(text.c_str(), color); -} - -void 
ImGuiColorPicker4(const std::string &text, float *color) { - ImGui::ColorPicker4(text.c_str(), color); -} - -void StartDebugRegion(const std::string& name) { - glPushDebugGroup(GL_DEBUG_SOURCE_APPLICATION, 0, (GLsizei)name.size(), name.c_str()); -} - -void EndDebugRegion() { glPopDebugGroup(); } - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/OpenGL/OpenGL.h b/TensorFrost/Backend/Backends/OpenGL/OpenGL.h deleted file mode 100644 index 8f7c0561..00000000 --- a/TensorFrost/Backend/Backends/OpenGL/OpenGL.h +++ /dev/null @@ -1,54 +0,0 @@ -#pragma once - -#include "imgui.h" -#include "imgui_impl_glfw.h" -#include "imgui_impl_opengl3.h" - -#include "glad/gl.h" -#include "GLFW/glfw3.h" - -#include "Memory.h" -#include "KernelManager.h" - -namespace TensorFrost { - -void StartOpenGL(); - -void StopOpenGL(); - -void ShowWindow(int width, int height, const char* title); -void HideWindow(); - -void Finish(); - -void RenderFrame(const TFTensor* tensor); - -bool WindowShouldClose(); - -pair GetMousePosition(); -pair GetWindowSize(); - -bool IsMouseButtonPressed(int button); -bool IsKeyPressed(int key); - -void ImGuiBegin(std::string name); -void ImGuiEnd(); - -void ImGuiText(const std::string& text); -void ImGuiSlider(std::string text, int* value, int min, int max); -void ImGuiSlider(std::string text, float* value, float min, float max); -bool ImGuiCheckbox(std::string text, bool* value); -bool ImGuiButton(std::string text); - -void ImGuiPlotLines(const char* label, const float* values, int values_count, int values_offset = 0, const char* overlay_text = NULL, float scale_min = FLT_MAX, float scale_max = FLT_MAX, ImVec2 graph_size = ImVec2(0, 0), int stride = sizeof(float)); - -void ImGuiScaleAllSizes(float scale); - -void ImGuiAddBackgroundText(const std::string& text, const ImVec2& pos, const ImVec4& color); -void ImGuiColorPicker3(const std::string& text, float* color); -void ImGuiColorPicker4(const std::string& text, float* 
color); - -void StartDebugRegion(const std::string& name); -void EndDebugRegion(); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Generators.cpp b/TensorFrost/Backend/CodeGen/Generators.cpp deleted file mode 100644 index 62bdae9f..00000000 --- a/TensorFrost/Backend/CodeGen/Generators.cpp +++ /dev/null @@ -1,556 +0,0 @@ -#ifndef TENSORFROST_BACKEND_CODEGEN_GENERATORS_CPP -#define TENSORFROST_BACKEND_CODEGEN_GENERATORS_CPP - -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { -using namespace std; - -void GenerateKernel(Program* program, Kernel* kernel) { - switch (current_kernel_lang) { - case CodeGenLang::CPP: - GenerateCPPKernel(program, kernel); - return; - case CodeGenLang::HLSL: - GenerateHLSLKernel(program, kernel); - return; - case CodeGenLang::GLSL: - GenerateGLSLKernel(program, kernel); - return; - default: - throw std::runtime_error("Code generation for this language is not implemented yet"); - } -} - -unordered_set forbidden_names = { - "unsigned", "input", "output", "max", "min", "exp", "sin", "cos", "if", "else", "while", "for", "switch", "case", "default", "break", - "this", "true", "false", "null", "new", "delete", "return", "continue", "goto", "try", "catch", "throw", - "const", "static", "extern", "inline", "virtual", "override", "final", "public", "protected", "private", "sample", - "texture", "sampler", "uniform", "varying", "attribute", "in", "out", "inout", "layout", "precision", "highp", - "mediump", "lowp", "noperspective", "flat", "smooth", "centroid", "patch", "sample", "subroutine", "common", - "partition", "active", "asm", "class", "union", "enum", "typedef", "template", "typename", "using", "namespace", - "friend", "sizeof", "alignof", "typeid", "dynamic_cast", "static_cast", "const_cast", "reinterpret_cast", "sizeof", - "alignof", "typeid", "noexcept", "throw", "auto", "register", "explicit", "mutable", "thread_local", - "constexpr", "decltype", "noexcept", "nullptr", 
"alignas", "and", "and_eq", "bitand", "bitor", "compl", "not", "not_eq", "or", - "sign", "xor", "xor_eq", "bool", "break", "case", "char" -}; - -bool IsForbiddenName(const string& name) { - return forbidden_names.contains(name); -} - -void GenerateNodeNames(const IR& ir) { - int var_index = 0; - int mem_index = 0; - int cluster_index = 0; - Node* curent_cluster = nullptr; - map name_count; - for (auto node = ir.begin(); !node.end(); node.next()) { - if (strip_debug_names && !node->debug_name.empty()) { - static int count = 0; - node->debug_name = "id_" + std::to_string(count++); - } - if (node->parent != curent_cluster) { - cluster_index++; - var_index = 0; - } - string debug = node->debug_name; - if (!debug.empty()) { - // check if the name is already used - if (name_count.contains(debug)) { - name_count[debug]++; - debug = debug + "_" + to_string(name_count[debug]); - } - else { - name_count[debug] = 1; - } - if (IsForbiddenName(debug) ) { - debug = debug + "0"; - } - node->var_name = debug; - } - else - { - if (node->name == "memory") { - node->var_name = debug + "m" + to_string(mem_index); - mem_index++; - } else { - node->var_name = - debug + "v" + to_string(cluster_index) + "_" + to_string(var_index); - var_index++; - } - } - - curent_cluster = node->parent; - } -} - - -string GetBufferDeclarations(Kernel *kernel, function get_name) { - map memory_bindings = kernel->GetMemoryBindings(); - - vector buffer_declarations = vector(memory_bindings.size()); - for (auto& buffer : memory_bindings) { - Node* mem_node = buffer.first; - size_t binding = buffer.second; - string name = mem_node->var_name; - string type_name = "uint"; - buffer_declarations[binding] = get_name(name, type_name, binding); - } - - string final_source; - for (auto& decl : buffer_declarations) { - final_source += decl; - } - - return final_source; -} - -string GetGroupBufferDeclarations(Kernel *kernel, function get_shared_name) { - string final_source; - - // add group memory declarations - for 
(auto& mem : kernel->group_memory) { - string name = mem->var_name; - string type_name = type_names[mem->format.type]; - //TODO: add support for non 32 bit types - final_source += get_shared_name(name, type_name, mem->data[0]); - } - - return final_source; -} - -string ReadVariable(Node* node) { - if (node->name == "const") { - return to_string(node->data[0]); - } - if (node->name == "memory") { - return "mem[" + node->var_name + "]"; - } - return node->var_name; -} - -string GetNodeName(const Node* node, bool compact) { - string name = node->var_name; - if (compact) { - if (node->name == "const" && !node->flags.has(NodeProp::Modified)) { - name = node->GetTensor()->GetConstantString(); - } - } - else { - if (node->name == "const") { - name = node->var_name + "(" + node->GetTensor()->GetConstantString() + ")"; - } - } - if (name.empty()) { - name = node->name + "_" + to_string(node->debug_index); - } - return name; -} - -#ifdef HAS_FORMAT -string format_float(double x) { - std::string s = std::format("{}", x); - if (s.find('.') == std::string::npos && s.find('e') == std::string::npos) { - s += '.'; - } - return s + 'f'; -} -#else -string format_float(double value) { - std::ostringstream out; - - // Determine when to use scientific notation vs fixed - bool use_scientific = std::abs(value) < 1e-4 || std::abs(value) > 1e6; - //and if not a zero value - use_scientific = use_scientific && value != 0.0; - if (use_scientific) { - out << std::scientific; // Use scientific notation for very small or large - // numbers - } else { - out << std::fixed; // Use fixed notation for moderate values - } - - out << std::setprecision(7) << value; - - // Convert to string - std::string str = out.str(); - - // Remove trailing zeros and potentially unnecessary decimal point - size_t endpos = str.find_last_not_of('0'); - if (endpos != std::string::npos) { - str = str.substr(0, endpos + 1); - } - if (str.back() == '.') { - str.pop_back(); - } - - // remove all zeros before "e" - size_t 
epos = str.find('e'); - if (epos != std::string::npos) { - size_t startpos = str.find_last_not_of('0', epos - 1); - if (startpos != std::string::npos) { - str = str.substr(0, startpos + 1) + str.substr(epos); - } - } - - if (str.find('.') == string::npos && str.find('e') == string::npos) { - str += '.'; - } - - // add a zero digit after the decimal point if the next character is not a - // digit - size_t dotpos = str.find('.'); - if (dotpos != std::string::npos && !isdigit(str[dotpos + 1])) { - //add a zero digit after the decimal point - str.insert(dotpos + 1, "0"); - } - - return str + 'f'; -} -#endif - -inline string Tensor::GetConstantString() const { - if (node_->name == "const" || node_->name == "dim_id") { - switch (node_->format.type) { - case TFType::Float: - return format_float(AsFloat(node_->data[0])); - case TFType::Int: - return to_string(AsInt(node_->data[0])); - case TFType::Uint: - return to_string(node_->data[0]) + "u"; - case TFType::Bool: - return node_->data[0] == 0 ? "false" : "true"; - default: - throw std::runtime_error("Unsupported constant type"); - } - } else { - return ""; - } -} - -void CodeGenerator::GenerateKernelCode(Kernel* kernel_) { - kernel = kernel_; - variables = kernel->variables; - read_write_bindings = kernel->read_write_memory; - read_only_bindings = kernel->read_only_memory; - GenerateCode(kernel->root); -} - -void CodeGenerator::GenerateCode(const Node* root) { - int variable_index = 0; - int memory_index = 0; - int prev_depth = 0; - // Translate each operation into HLSL - for (auto node = NodeIterator(root); !node.end(); node->name == "kernel" ? 
node.forward() : node.next()) { - string name = node->var_name; - - int depth = node.depth() - 1; - if (depth != prev_depth) { - // add scope brackets - if (depth < prev_depth) { - for (int i = prev_depth - 1; i >= depth; i--) { - lines.push_back(new Line(i, "}")); - } - } else if (depth > prev_depth) { - for (int i = prev_depth; i < depth; i++) { - lines.push_back(new Line(i, "{")); - } - } - } - - Line* line = nullptr; - if (custom_generated_code_.contains(*node)) { - line = new Line(*node, "", custom_generated_code_[*node], ";", ""); - } else { - // get node arguments - line = GenerateLine(*node); - } - - if (line == nullptr) { - continue; - } - - line->indent = depth; - lines.push_back(line); - - for (auto additional: additional_lines) { - lines.push_back(new Line(depth, additional)); - } - additional_lines.clear(); - - prev_depth = depth; - } - - // add closing brackets - for (int i = prev_depth - 1; i >= 0; i--) { - lines.push_back(new Line(i, "}")); - } - - //remove lines - unordered_set remove_lines; - for (auto& line : lines) { - if (lines_to_remove.contains(line->node)) { - remove_lines.insert(line); - } - } - - for (auto& line : remove_lines) { - lines.erase(std::remove(lines.begin(), lines.end(), line), lines.end()); - } -} - - -string CodeGenerator::AssembleString() { - string code; - int indent = 0; - for (auto& line : lines) { - for (int i = 0; i < line->indent; i++) { - code += " "; - } - code += line->left; - code += line->expression; - code += line->right; - code += "\n"; - } - return code; -} - -Line* CodeGenerator::GenerateLine(Node* node) { - ArgumentManager& args = node->args; - if (kernel) RegenerateNodeName(node); - GenerateArgumentNames(args); - const Operation* op = node->op; - - string name = node->var_name; - - // get output type - TFDataFormat output_format = node->format; - //TODO: add support for non 32 bit types - - // generate line - string left = ""; - string expression = ""; - string right = ""; - bool needs_paranthesis = false; - 
- if (op->HasAllTypes(OpProp::Special)) { - int dims = args.Count(ArgType::Shape); - - string shape_arg = "{"; - - for (int j = 0; j < dims; j++) { - if (j != 0) { - shape_arg += ", "; - } - Node* shape_node = args.Get(ArgType::Shape, j); - - shape_arg += "(uint)" + args.Name(ArgType::Shape, dims - j - 1); - } - - shape_arg += "}"; - - if (op->name_ == "loop") { - left += GenerateLoop(&args, name); - } else if (op->name_ == "if") { - left += GenerateIf(&args); - } else if (op->name_ == "memory") { - // if input memory type then just take the input and store it in the - // output - if (node->flags.has(NodeProp::InputMemory)) { - left += "tf.check_tensor(" + node->var_name+ ", \"" + node->var_name + "\", " + shape_arg + ", " + DataFormatNames[output_format] + ")"; - right += ";"; - } - // if any other memory type - allocate it - else { - left += "TFTensor " + node->var_name + " = "; - expression += "tf.allocate(\"" + node->var_name + "\", " + shape_arg + ", " + DataFormatNames[output_format] + ")"; - right += ";"; - } - } else if (op->name_ == "deallocate") { - left = "tf.deallocate(" + args.Name(ArgType::Memory) + ")"; - right = ";"; - } else if (op->name_ == "input_shape") { - left = "int " + node->var_name + " = "; - expression = ir->input_memory_map[(int)node->flags.get(NodeProp::InputShapeMemory)]->var_name + ".shape[" + to_string((int)node->flags.get(NodeProp::InputShapeDim)) + "]"; - right = ";"; - } else if(op->HasAllTypes(OpProp::MemoryReuse)) { - left = "TFTensor " + node->var_name + " = "; - expression = "tf." + op->code_ + "(" + args.Name(ArgType::Memory) + ", \"" + node->var_name + "\", " + shape_arg + ", " + DataFormatNames[output_format] + ")"; - right = ";"; - } else if(op->HasAllTypes(OpProp::Debug)) { - left = "tf." 
+ op->code_ + "(\"" + node->debug_name + "\""; - if (args.Has(ArgType::Input)) { - left += ", " + args.Name(ArgType::Input); - } - left += ")"; - right = ";"; - } else if(op->name_ == "local_memory") { - left = type_names[output_format.type] + " " + name + "[" + to_string(node->data[0]) + "]"; - right = ";"; - } else if(op->name_ == "group_memory") { - left = ""; - //just leave as comment, actual declaration is done outside of the main body of the kernel - right = "//" + type_names[output_format.type] + " " + name + "[" + to_string(node->data[0]) + "]"; - } - } else if (op->HasAllTypes(OpProp::MemoryOp)) { - string address; - - if (kernel) { - address = "0"; - // if has index (not a scalar) - if (args.Has(ArgType::Index)) { - address = args.Name(ArgType::Index); - } - - bool is_local = node->flags.has(NodeProp::LocalMemoryOp); - string memory_name = args.Name(ArgType::Memory) + (is_local ? "" : "_mem"); - string memory_expression = memory_name + "[" + address + "]"; - TFType memory_type = is_local ? node->format.type : Uint; - string memory_type_name = type_names[memory_type]; - - if (op->name_ == "load") { - string output_type_name = type_names[output_format.type]; - left += output_type_name + " " + name + " = "; - expression += - (output_format.type == memory_type) - ? memory_expression - : TypeReinterpret(output_type_name, memory_expression); - right += ";"; - } else if (op->name_ == "store") { - expression += memory_expression + " = "; - expression += - (output_format.type == memory_type) - ? 
args.Name(ArgType::Input) - : TypeReinterpret(memory_type_name, args.Name(ArgType::Input)); - right += ";"; - } else if (op->HasAllTypes(OpProp::Scatter)) { - if (output_format.type != None) { - left += type_names[output_format.type] + " " + name + " = "; - } - string output_type_name = type_names[output_format.type]; - string input_type_name = type_names[args.Format(ArgType::Input).type]; - expression += GenerateAtomicOp(op->name_, input_type_name, - output_type_name, address, - args.Name(ArgType::Input), name, memory_name); - right += ";"; - } - } else { - string tensor_name = args.Name(ArgType::Memory); - string address = "0"; - if (args.Has(ArgType::Index)) { - address = args.Name(ArgType::Index); - } - - if (op->name_ == "load") { - //do readback - string output_type_name = type_names[output_format.type]; - left += output_type_name + " " + name + " = "; - string memory_expression = GetName("tf.read") + "(" + tensor_name + ", " + address + ")"; - expression += (output_format.type == Uint) - ? 
memory_expression - : TypeReinterpret(output_type_name, memory_expression); - right += ";"; - } else if (op->name_ == "store") { - //do writeback - string memory_expression = GetName("tf.write") + "(" + tensor_name + ", " + address + ", "; - expression += memory_expression + args.Name(ArgType::Input) + ")"; - right += ";"; - } else if (op->HasAllTypes(OpProp::Scatter)) { - throw std::runtime_error("Scatter operation not supported in non-kernel mode"); - } - } - - } else if (op->name_ == "set") { - left += args.Name(ArgType::Memory) + " = "; - expression += args.Name(ArgType::Input); - right += ";"; - } else { - if (output_format.type != None) { - left += type_names[output_format.type] + " " + name + " = "; - } - string line; - string code = op->code_; - switch (op->class_) { - case OpClass::Operator: - args.AddParenthesis(true); - if ((code == "&" || code == "|") && output_format.type == Bool) { - code = code + code; - } - line += args.Name(ArgType::Input, 0) + " " + code + " " + - args.Name(ArgType::Input, 1); - needs_paranthesis = true; - break; - case OpClass::UnaryOperator: - args.AddParenthesis(true); - line += op->code_ + args.Name(ArgType::Input, 0); - needs_paranthesis = true; - break; - case OpClass::Function: - line += GetName(op->code_) + "("; - for (int i = 0; i < args.Count(ArgType::Input); i++) { - if (i != 0) { - line += ", "; - } - line += args.Name(ArgType::Input, i); - } - line += ")"; - break; - case OpClass::Copy: - line += args.Name(ArgType::Input, 0); - needs_paranthesis = true; - break; - case OpClass::Keyword: - line += op->code_; - break; - case OpClass::DimensionIndex: - line += op->code_ + to_string(node->data[0]); - break; - case OpClass::Variable: - line += op->code_; - break; - case OpClass::TypeCast: - line += GenerateTypeCast(&args, DataTypeToString(node->format.type)); - break; - case OpClass::TypeReinterpret: - line += GenerateTypeReinterpret(&args, DataTypeToString(node->format.type)); - break; - case OpClass::Constant: - line += 
node->GetTensor()->GetConstantString(); - break; - case OpClass::TernaryOperator: - args.AddParenthesis(true); - line += args.Name(ArgType::Input, 0) + " ? " + - args.Name(ArgType::Input, 1) + " : " + - args.Name(ArgType::Input, 2); - needs_paranthesis = true; - break; - default: - throw std::runtime_error("Unknown operation class"); - break; - } - expression += line; - right += ";"; - } - - node_expression[node] = expression; - requires_paranthesis[node] = needs_paranthesis; - - return new Line(node, left, expression, right, name); -} - -string AddIndent(const string& input, const string& indent) { - stringstream ss(input); - string line; - string indentedText; - - while (getline(ss, line)) { - indentedText += indent + line + "\n"; - } - - return indentedText; -} - -} // namespace TensorFrost - -#endif \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Generators.h b/TensorFrost/Backend/CodeGen/Generators.h deleted file mode 100644 index 2de548ea..00000000 --- a/TensorFrost/Backend/CodeGen/Generators.h +++ /dev/null @@ -1,205 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include -#include - -#if __has_include() - #include - #define HAS_FORMAT 1 -#else - #define HAS_FORMAT 0 -#endif - -#include "Compiler/KernelGen.h" -#include "Tensor/Tensor.h" -#include "Backend/Backend.h" - -namespace TensorFrost { - -string GetNodeName(const Node* node, bool compact = false); -string ReadVariable(Node* node); -void GenerateNodeNames(const IR& ir); - -string GetBufferDeclarations(Kernel* kernel, function get_name); -string GetGroupBufferDeclarations(Kernel* kernel, function get_shared_name); -string GetCPPHeader(); -string GetCPPImplementation(); -string GetHLSLHeader(Kernel* kernel); -string GetGLSLHeader(Kernel* kernel); -void GenerateMain(Program* program, map& dispatch_code, int input_count, int output_count); -void GenerateKernel(Program* program, Kernel* kernel); -void GenerateCPPKernel(Program* program, Kernel* kernel); -void 
GenerateHLSLKernel(Program* program, Kernel* kernel); -void GenerateGLSLKernel(Program* program, Kernel* kernel); -void GenerateCode(Program* program); - -string GetNodeString(const Node* node, bool verbose = false); -string GetOperationListing(const IR&, bool compact = false, - map invalid = {}); - -bool IsForbiddenName(const string& name); - -using ArgumentNames = map; - -class Line { - public: - Node* node; - string left; - string expression; - string right; - string name; - int indent; - - Line(Node* node, string left, string expression, string right, string name, int indent = 0) - : left(left), right(right), name(name), indent(indent), expression(expression), node(node) {} - - Line(int indent, string expression) - : indent(indent), expression(expression), left(""), right(""), name(""), node(nullptr) {} -}; - -class CodeGenerator { -protected: - unordered_map name_map_; - set used_names_; - public: - list lines; - map custom_generated_code_; - - Kernel* kernel = nullptr; - IR* ir = nullptr; - - CodeGenerator(IR* ir) : ir(ir) {} - - void GenerateKernelCode(Kernel *kernel_); - void GenerateCode(const Node* root); - string AssembleString(); - -protected: - map read_write_bindings; - map read_only_bindings; - map variables; - map node_expression; - map requires_paranthesis; - unordered_set lines_to_remove; - vector additional_lines; - unordered_map name_count; - - virtual void GenerateArgumentNames(ArgumentManager& args) { - for (auto& arg : args.Inputs()) { - Node* node = arg.second; - ArgID id = arg.first; - string name = node->var_name; - bool need_parenthesis = false; - if (variables.contains(node)) { - name = GetName("var") + name; - } else { - string expr = node_expression[node]; - bool is_memory = node->op->HasAllTypes(OpProp::Memory); - bool is_static = node->op->HasAllTypes(OpProp::Static) || - node->op->HasAllTypes(OpProp::CantSubstitute); - bool is_constant = node->op->class_ == OpClass::Constant; - bool is_variable = node->op->class_ == 
OpClass::Variable; - if (is_constant && expr == "") { - expr = node->GetTensor()->GetConstantString(); - } - bool has_name = node->debug_name != ""; - bool has_single_output = (node->args.OutputCount() == 1) || is_constant || is_variable; - bool modified = node->flags.has(NodeProp::Modified); - bool short_enough = expr.size() < 100; - bool can_substitude = !has_name && has_single_output && !modified && short_enough && !is_static && !is_memory; - if (can_substitude) { - if (expr == "") { - throw std::runtime_error("Substitute expression is empty"); - } - name = expr; - need_parenthesis = requires_paranthesis[node]; - lines_to_remove.insert(node); - } - } - args.SetName(id, name, need_parenthesis); - } - } - - void RegenerateNodeName(Node* node) { - if(node->op->HasAllTypes(OpProp::LocalMemory)) { - used_names_.insert(node->var_name); - return; - } - string debug = node->debug_name; - if (debug.empty()) { - debug = "v" + node->name;//return; - } - if (IsForbiddenName(debug)) { - debug = debug + "0"; - } - // check if the name is already used - if (name_count.contains(debug)) { - name_count[debug]++; - debug = debug + "_" + to_string(name_count[debug]); - } else { - name_count[debug] = 1; - } - if (used_names_.contains(debug)) { - debug = debug + "_"; - } - node->var_name = debug; - } - - Line* GenerateLine(Node* node); - - virtual string GenerateLoop(ArgumentManager* args, const string& name) - { - string in1 = args->Name(ArgType::Input, 0), in2 = args->Name(ArgType::Input, 1), in3 = args->Name(ArgType::Input, 2); - return "for (int " + name + " = " + in1 + "; " + name + " < " + in2 + "; " + name + " += " + in3 + ")"; - } - - virtual string GenerateIf(ArgumentManager* args) - { - return "if (" + args->Name(ArgType::Input, 0) + ")"; - } - - virtual string TypeCast(string type_name, string input) - { - return "((" + type_name + ")(" + input + "))"; - } - - virtual string TypeReinterpret(string type_name, string input) { - return "as" + type_name + "(" + input + ")"; - 
} - - virtual string GenerateTypeCast(ArgumentManager* args, const string& type_name) - { - return TypeCast(type_name, args->Name(ArgType::Input, 0)); - } - - virtual string GenerateTypeReinterpret(ArgumentManager* args, const string& type_name) - { - return TypeReinterpret(type_name, args->Name(ArgType::Input, 0)); - } - - virtual string GenerateAtomicOp(const string& op, - const string& input_type_name, - const string& output_type_name, - const string& address, const string& input, const string& output, const string& memory_name) - { - return op + "((" + input_type_name + "*)" + memory_name + ", " + address + ", " + input + ")"; - } - - string GetName(const string& name) { - // Check if the function name is in the map - if (name_map_.find(name) != name_map_.end()) { - return name_map_[name]; - } - - // If not, return the original name - return name; - } -}; - -string AddIndent(const string& input, const string& indent); - - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Langs/CPP.cpp b/TensorFrost/Backend/CodeGen/Langs/CPP.cpp deleted file mode 100644 index a25b5f68..00000000 --- a/TensorFrost/Backend/CodeGen/Langs/CPP.cpp +++ /dev/null @@ -1,781 +0,0 @@ -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { -using namespace std; - - -class CPPGenerator : public CodeGenerator { -public: - CPPGenerator(IR* ir) : CodeGenerator(ir) { - name_map_ = { - {"sqrt", "sqrtf"}, - {"var", "var_"}, - }; - } -}; - -string GetCPPHeader() { - string header = R"( -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -typedef uint32_t uint; - -inline int min(int a, int b) -{ - return a < b ? a : b; -} - -inline int max(int a, int b) -{ - return a > b ? a : b; -} - -inline float min(float a, float b) -{ - return a < b ? a : b; -} - -inline float max(float a, float b) -{ - return a > b ? 
a : b; -} - -inline float asfloat(uint x) -{ - return *(float*)&x; -} - -inline uint asuint(float x) -{ - return *(uint*)&x; -} - -inline uint asuint(double x) -{ - return asuint((float)x); -} - -inline uint asuint(int x) -{ - return *(uint*)&x; -} - -inline uint asuint(uint x) -{ - return *(uint*)&x; -} - -inline int asint(uint x) -{ - return *(int*)&x; -} - -inline uint asuint(bool x) -{ - return *(uint*)&x; -} - -inline bool asbool(uint x) -{ - return *(bool*)&x; -} - -inline int clamp(int x, int a, int b) -{ - return min(max(x, a), b); -} - -inline float clamp(float x, float a, float b) -{ - return min(max(x, a), b); -} - -inline float lerp(float a, float b, float t) -{ - return a + (b - a) * t; -} - -inline float smoothstep(float a, float b, float t) -{ - t = clamp((t - a) / (b - a), 0.0f, 1.0f); - return t * t * (3.0f - 2.0f * t); -} - -inline float sign(float x) -{ - return x < 0.0f ? -1.0f : 1.0f; -} - -inline int sign(int x) -{ - return x < 0 ? -1 : 1; -} - -inline uint reversebits(uint x) -{ - x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1)); - x = (((x & 0xcccccccc) >> 2) | ((x & 0x33333333) << 2)); - x = (((x & 0xf0f0f0f0) >> 4) | ((x & 0x0f0f0f0f) << 4)); - x = (((x & 0xff00ff00) >> 8) | ((x & 0x00ff00ff) << 8)); - return ((x >> 16) | (x << 16)); -} - -inline int reversebits(int x) -{ - return (uint)reversebits((uint)x); -} - -inline void InterlockedAdd(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_add(value, std::memory_order_relaxed); -} - -inline void InterlockedAdd(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_add(value, std::memory_order_relaxed); -} - -inline void InterlockedAdd(float* memory, int address, float value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - float current = place->load(std::memory_order_relaxed); - float goal = current + value; - while 
(!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = current + value; - } -} - -inline int InterlockedAdd_Prev(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - return place->fetch_add(value, std::memory_order_relaxed); -} - -inline uint InterlockedAdd_Prev(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - return place->fetch_add(value, std::memory_order_relaxed); -} - -inline float InterlockedAdd_Prev(float* memory, int address, float value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - float current = place->load(std::memory_order_relaxed); - float goal = current + value; - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = current + value; - } - return current; -} - -inline void InterlockedAnd(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_or(value, std::memory_order_relaxed); -} - -inline void InterlockedAnd(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_and(value, std::memory_order_relaxed); -} - -inline void InterlockedOr(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_or(value, std::memory_order_relaxed); -} - -inline void InterlockedOr(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_or(value, std::memory_order_relaxed); -} - -inline void InterlockedXor(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_xor(value, std::memory_order_relaxed); -} - -inline void InterlockedXor(uint* memory, int address, uint value) -{ - std::atomic* place = 
reinterpret_cast*>(&memory[address]); - place->fetch_xor(value, std::memory_order_relaxed); -} - -inline void InterlockedMin(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - int current = place->load(std::memory_order_relaxed); - int goal = min(current, value); - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = min(current, value); - } -} - -inline void InterlockedMin(float* memory, int address, float value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - float current = place->load(std::memory_order_relaxed); - float goal = min(current, value); - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = min(current, value); - } -} - -inline void InterlockedMax(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - int current = place->load(std::memory_order_relaxed); - int goal = max(current, value); - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = max(current, value); - } -} - -inline void InterlockedMax(float* memory, int address, float value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - float current = place->load(std::memory_order_relaxed); - float goal = max(current, value); - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = max(current, value); - } -} - -inline void group_barrier() {} //NOOP, TODO: implement properly - -inline uint pcg(uint v) -{ - uint state = v * 747796405u + 2891336453u; - uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -inline float pcgf(uint v) -{ - return (float)pcg(v) / (float)0xffffffffu; -} - -extern "C" { - enum TFType { - Float, - Uint, - Int, - Bool, - None, - }; - - 
struct TFDataFormat { - TFType type; - size_t size; - - bool operator==(const TFDataFormat& other) const { - return type == other.type && size == other.size; - } - - bool operator!=(const TFDataFormat& other) const { - return !(*this == other); - } - - int GetHash() const { - return (int)type << 16 | (int)size; - } - - bool operator<(const TFDataFormat& other) const { - return GetHash() < other.GetHash(); - } - - bool operator>(const TFDataFormat& other) const { - return GetHash() > other.GetHash(); - } - }; - -#define TFTypeNone TFDataFormat{TFType::None, 0} -#define TFTypeBool32 TFDataFormat{TFType::Bool, 32} -#define TFTypeFloat32 TFDataFormat{TFType::Float, 32} -#define TFTypeInt32 TFDataFormat{TFType::Int, 32} -#define TFTypeUint32 TFDataFormat{TFType::Uint, 32} - - struct TFBuffer { - size_t size = 0; - size_t used_size = 0; - size_t time_since_used = 0; - bool up_to_date = false; - bool read_only = false; - const char* name = nullptr; - //add type descriptor (for special kinds of buffers) - }; - - struct TFTensor { - TFBuffer* buffer; - TFDataFormat format; - size_t dim; - const size_t* shape; - }; - - struct TFTensorList { - size_t count; - const TFTensor* tensors; - }; - - struct TFDispatchInfo { - size_t kernel_id; - size_t read_write_count; - const TFTensor* read_write_tensors; - size_t read_only_count; - const TFTensor* read_only_tensors; - size_t variable_count; - const uint32_t* variables; - size_t work_group_count; - }; - - typedef TFTensor alloc_func(const char*, const size_t*, size_t, TFDataFormat, void*); - typedef void dealloc_func(TFTensor, void*); - typedef uint readback_func(TFTensor, size_t, void*); - typedef void writeback_func(TFTensor, size_t, uint32_t, void*); - typedef void dispatch_func(TFDispatchInfo, void*); - typedef void region_func(const char*, bool, void*); - - struct TFRuntime { - alloc_func* alloc; - dealloc_func* dealloc; - readback_func* readback; - writeback_func* writeback; - dispatch_func* dispatch; - region_func* region; - 
void* custom_data; - }; -} - -class TFContext -{ -public: - TFRuntime runtime; - - TFContext(TFRuntime runtime); - size_t compute_size(const size_t* shape, size_t dim); - TFTensor allocate(std::string name, std::initializer_list shape, TFDataFormat type); - void deallocate(TFTensor tensor); - void check_tensor(TFTensor tensor, std::string name, std::initializer_list target_shape, TFDataFormat target_type); - TFTensor reshape(TFTensor tensor, std::string name, std::initializer_list shape, TFDataFormat type); - TFTensor assert_tensor(TFTensor tensor, std::string name, std::initializer_list target_shape, TFDataFormat target_type); - uint32_t read(TFTensor tensor, size_t index); - void write(TFTensor tensor, size_t index, uint32_t value); - void dispatch(size_t kernel_id, std::initializer_list read_write, std::initializer_list read_only, std::initializer_list var, std::initializer_list shape, std::initializer_list group); - void region_begin(std::string name); - void region_end(std::string name); - template - void print_value(std::string name, T value); - void assert_value(std::string name, bool condition); -}; -)"; - - return header; -} - -string GetCPPImplementation() { - string implementation = R"( - -std::unordered_map TFTypeNames = { - {TFType::Float, "Float"}, {TFType::Uint, "Uint"}, - {TFType::Int, "Int"}, {TFType::Bool, "Bool"}, - {TFType::None, "None"}, -}; - -TFContext::TFContext(TFRuntime runtime) : runtime(runtime) {} - -size_t TFContext::compute_size(const size_t* shape, size_t dim) { - size_t size = 1; - for (size_t i = 0; i < dim; i++) { - size *= shape[i]; - } - return size; -} - -TFTensor TFContext::allocate(std::string name, std::initializer_list shape, TFDataFormat type) -{ - const size_t* shape_arr = shape.begin(); - size_t dim = shape.size(); - size_t size = compute_size(shape_arr, dim); - - for (size_t i = 0; i < dim; i++) { - if(shape_arr[i] < 1) { - throw std::runtime_error("Invalid shape on dimension " + std::to_string(i) + " for " + name + ". 
Expected positive integer, got " + std::to_string(shape_arr[i])); - } - } - - return runtime.alloc(name.c_str(), shape_arr, dim, type, runtime.custom_data); -} - -void TFContext::deallocate(TFTensor tensor) -{ - runtime.dealloc(tensor, runtime.custom_data); -} - -void TFContext::check_tensor(TFTensor tensor, std::string name, std::initializer_list target_shape, TFDataFormat target_format) -{ - const size_t* shape_arr = tensor.shape; - const size_t* target_shape_arr = target_shape.begin(); - size_t target_dim = target_shape.size(); - - if (tensor.format != target_format) { - throw std::runtime_error("Invalid type for " + name + ". Expected " + TFTypeNames[target_format.type] + ", got " + TFTypeNames[tensor.format.type]); - } - - if (tensor.dim != target_dim) { - throw std::runtime_error("Invalid number of dimensions for " + name + ". Expected " + std::to_string(target_dim) + ", got " + std::to_string(tensor.dim)); - } - - for (size_t i = 0; i < tensor.dim; i++) { - if (shape_arr[i] != target_shape_arr[i] || target_shape_arr[i] < 1) { - throw std::runtime_error("Invalid shape for dimension " + std::to_string(i) + " in " + name + ". 
Expected " + std::to_string(target_shape_arr[i]) + ", got " + std::to_string(shape_arr[i])); - } - } -} - -TFTensor TFContext::reshape(TFTensor tensor, std::string name, std::initializer_list shape, TFDataFormat type) -{ - size_t* new_shape = new size_t[shape.size()]; - std::copy(shape.begin(), shape.end(), new_shape); - TFTensor new_tensor = {tensor.buffer, type, shape.size(), new_shape}; - - size_t old_size = compute_size(tensor.shape, tensor.dim); - size_t new_size = compute_size(new_tensor.shape, new_tensor.dim); - - if(old_size != new_size) { - throw std::runtime_error("Cannot reshape " + name + ", expected " + std::to_string(new_size) + " elements, while input has " + std::to_string(old_size)); - } - - return new_tensor; -} - -TFTensor TFContext::assert_tensor(TFTensor tensor, std::string name, std::initializer_list target_shape, TFDataFormat target_type) -{ - check_tensor(tensor, name, target_shape, target_type); - return tensor; -} - -uint TFContext::read(TFTensor tensor, size_t index) -{ - return runtime.readback(tensor, index, runtime.custom_data); -} - -void TFContext::write(TFTensor tensor, size_t index, uint32_t value) -{ - runtime.writeback(tensor, index, value, runtime.custom_data); -} - -void TFContext::dispatch(size_t kernel_id, std::initializer_list read_write, std::initializer_list read_only, std::initializer_list var, std::initializer_list shape, std::initializer_list group) -{ - //currently only supports read_write tensors - std::vector all_tensors; - all_tensors.insert(all_tensors.end(), read_write.begin(), read_write.end()); - all_tensors.insert(all_tensors.end(), read_only.begin(), read_only.end()); - std::vector all_vars; - all_vars.insert(all_vars.end(), var.begin(), var.end()); - all_vars.push_back(0); //group index offset - TFDispatchInfo info = {kernel_id, all_tensors.size(), all_tensors.data(), 0, nullptr, (uint)all_vars.size(), all_vars.data(), 0}; - - const TFTensor* read_write_tensors = read_write.begin(); - for (size_t i = 0; i < 
read_write.size(); i++) { - read_write_tensors[i].buffer->up_to_date = false; - } - - const size_t* shape_arr = shape.begin(); - const size_t* group_arr = group.begin(); - size_t dispatch_dim = shape.size(); - size_t group_dim = group.size(); - - size_t work_group_count = 1; - for (size_t i = 0; i < dispatch_dim - group_dim; i++) { - work_group_count *= shape_arr[i]; - } - - //only the last dimensions are divided by the group size - //TODO: consider reversing indices on backend too - for (size_t i = 0; i < group_dim; i++) { - size_t dim = shape_arr[dispatch_dim - i - 1]; - work_group_count *= (dim + group_arr[i] - 1) / group_arr[i]; - } - - info.work_group_count = work_group_count; - - runtime.dispatch(info, runtime.custom_data); -} - -void TFContext::region_begin(std::string name) -{ - if(runtime.region == nullptr) { - return; - } - runtime.region(name.c_str(), true, runtime.custom_data); -} - -void TFContext::region_end(std::string name) -{ - if(runtime.region == nullptr) { - return; - } - runtime.region(name.c_str(), false, runtime.custom_data); -} - -template -void TFContext::print_value(std::string name, T value) -{ - std::cout << name << ": " << value << std::endl; -} - -void TFContext::assert_value(std::string name, bool condition) -{ - if(!condition) { - throw std::runtime_error("Assertion failed: " + name); - } -} - -)"; - return implementation; -} - -void GenerateCode(Program* program) { - string final_source = GetCPPHeader(); - final_source += GetCPPImplementation(); - - GenerateNodeNames(*program->ir_); - int input_count = (int)program->ir_->input_memory_map.size(); - int output_count = (int)program->ir_->output_memory_map.size(); - - // Generate code for each compute kernel - map dispatch_code; - - for (auto& kernel : program->kernels_) { - global_kernel_manager->AddKernelID(program, &kernel); - kernel.kernel_name_ = "kernel_" + to_string(kernel.kernel_id_); - - // Generate kernel - vector read_write_nodes; - 
read_write_nodes.resize(kernel.read_write_memory.size()); - for (auto& read_write : kernel.read_write_memory) { - read_write_nodes[read_write.second] = read_write.first; - } - vector read_only_nodes; - read_only_nodes.resize(kernel.read_only_memory.size()); - for (auto& read_only : kernel.read_only_memory) { - read_only_nodes[read_only.second] = read_only.first; - } - - vector variable_nodes; - variable_nodes.resize(kernel.variables.size()); - for (auto& variable : kernel.variables) { - variable_nodes[variable.second] = variable.first; - } - - string read_write_args = "{"; - for (int d = 0; d < read_write_nodes.size(); d++) { - if (d != 0) { - read_write_args += ", "; - } - read_write_args += read_write_nodes[d]->var_name; - } - read_write_args += "}"; - string read_only_args = "{"; - for (int d = 0; d < read_only_nodes.size(); d++) { - if (d != 0) { - read_only_args += ", "; - } - read_only_args += read_only_nodes[d]->var_name; - } - read_only_args += "}"; - - string variable_args = "{"; - for (int d = 0; d < variable_nodes.size(); d++) { - if (d != 0) { - variable_args += ", "; - } - variable_args += "asuint(" + ReadVariable(variable_nodes[d]) + ")"; - } - variable_args += "}"; - - string shape_args = "{"; - int dims = (int)kernel.shape.size(); - for (int d = 0; d < kernel.shape.size(); d++) { - if (d != 0) { - shape_args += ", "; - } - shape_args += "(uint)" + ReadVariable(kernel.shape[ArgID(ArgType::Shape, dims - d - 1)]); - } - shape_args += "}"; - - string group_args = "{"; - for (int d = 0; d < kernel.root->group_size.size(); d++) { - if (d != 0) { - group_args += ", "; - } - group_args += to_string(kernel.root->group_size[d]); - } - group_args += "}"; - - GenerateKernel(program, &kernel); - - if (current_backend == BackendType::CPU) { - final_source += kernel.full_generated_code_; - } - - dispatch_code[kernel.root] = "tf.dispatch(" + to_string(kernel.kernel_id_) + ", " + read_write_args + ", " + read_only_args + ", " + variable_args + ", " + shape_args + ", 
" + group_args + ")"; - } - - GenerateMain(program, dispatch_code, input_count, output_count); - - final_source += program->main_function_; - - string host_code = - "\n" - "extern \"C\" " -#ifdef _WIN32 - "__declspec(dllexport) " -#endif - "int " - "main" - "(TFTensor* in, TFTensor* out, TFRuntime runtime)\n" - "{\n" - " auto outputs = " + program->program_name + "(TFContext(runtime)"; - - if (input_count > 0) { - host_code += ", "; - } - - for (int i = 0; i < input_count; i++) { - host_code += "in[" + to_string(i) + "]"; - if (i != input_count - 1) { - host_code += ", "; - } - } - host_code += ");\n"; - - for (int i = 0; i < output_count; i++) { - host_code += " out[" + to_string(i) + "] = std::get<" + to_string(i) + ">(outputs);\n"; - } - - host_code += " return 0;\n}\n"; - - final_source += host_code; - - program->generated_code_ = final_source; -} -void GenerateMain(Program* program, map& dispatch_code, - int input_count, int output_count) { - CPPGenerator generator = CPPGenerator(program->ir_); - generator.custom_generated_code_ = dispatch_code; - generator.GenerateCode(program->ir_->root); - - string main_code = "\nstd::tuple<"; - for (int i = 0; i < output_count; i++) { - main_code += "TFTensor"; - if (i != output_count - 1) { - main_code += ", "; - } - } - main_code += "> " + program->program_name + "(TFContext tf"; - if (input_count > 0) { - main_code += ", "; - } - for (int i = 0; i < input_count; i++) { - Node* input_node = program->ir_->input_memory_map[i]; - main_code += "TFTensor " + input_node->var_name; - if (i != input_count - 1) { - main_code += ", "; - } - } - main_code += ")\n{\n"; - - main_code += AddIndent(generator.AssembleString(), " "); - - main_code += " return {"; - - for (int i = 0; i < output_count; i++) { - Node* output_node = program->ir_->output_memory_map[i]; - main_code += output_node->var_name; - if (i != output_count - 1) { - main_code += ", "; - } - } - main_code += "};\n}\n"; - - program->main_function_ = main_code; -} - -void 
GenerateCPPKernel(Program* program, Kernel* kernel) { - CPPGenerator generator = CPPGenerator(program->ir_); - generator.GenerateKernelCode(kernel); - string kernel_code = generator.AssembleString(); - - string loop = ""; - loop += GetBufferDeclarations(kernel, [](const string& name, const string& type_name, size_t binding) { - return " uint* " + name + "_mem = mem[" + to_string(binding) + "];\n"; - }); - - kernel->var_names = vector(kernel->variables.size()); - kernel->var_types = vector(kernel->variables.size()); - for (auto var : kernel->variables) { - kernel->var_names[var.second] = var.first->var_name; - kernel->var_types[var.second] = type_names[var.first->format.type]; - } - kernel->var_names.push_back("_kernel_block_offset"); - kernel->var_types.push_back(type_names[TFType::Uint]); - for (int i = 0; i < kernel->var_names.size(); i++) { - loop += " " + kernel->var_types[i] + " var_" + kernel->var_names[i] + " = as" + kernel->var_types[i] + "(var[" + to_string(i) + "]);\n"; - } - - loop += " #pragma omp parallel for\n"; - loop += " for (int block_id = var__kernel_block_offset; block_id < (work_group_count+var__kernel_block_offset); block_id++)\n"; - loop += " {\n"; - - loop += GetGroupBufferDeclarations(kernel, [](const string& name, const string& type_name, size_t size) { - return " " + type_name + " " + name + "[" + to_string(size) + "];\n"; - }); - - for (int d = 0; d < kernel->root->group_size.size(); d++) { - int dim = (int)kernel->root->group_size.size() - d - 1; - loop += " for (int block_thread_id" + to_string(dim) + - " = 0; block_thread_id" + to_string(dim) + " < " + - to_string(kernel->root->group_size[dim]) + "; block_thread_id" + - to_string(dim) + "++)\n"; - } - loop += " {\n"; - string loop_end = " }\n"; - loop_end += " }\n"; - loop += AddIndent(kernel_code, " "); - loop += loop_end; - - string kernel_source = - "\n" - "extern \"C\" " - #ifdef _WIN32 - "__declspec(dllexport) " - #endif - "void " + - kernel->kernel_name_ + - "(uint* var, uint** 
mem, uint work_group_count)\n" - "{\n" + loop + - "}\n"; - - kernel->full_generated_code_ = kernel_source; - kernel->generated_header_ = ""; - kernel->generated_bindings_ = ""; - kernel->generated_main_ = kernel_source; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Langs/GLSL.cpp b/TensorFrost/Backend/CodeGen/Langs/GLSL.cpp deleted file mode 100644 index a7a83558..00000000 --- a/TensorFrost/Backend/CodeGen/Langs/GLSL.cpp +++ /dev/null @@ -1,188 +0,0 @@ -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { -using namespace std; - -class GLSLGenerator : public CodeGenerator { - public: - GLSLGenerator(IR* ir) : CodeGenerator(ir) { - name_map_ = { - {"var", "var."}, - {"modf", "mod"}, - {"atan2", "atan"}, - {"lerp", "mix"}, - {"reversebits", "bitfieldReverse"}, - {"frac", "fract"}, - {"group_barrier", "barrier"} - }; - } - - string TypeCast(string type_name, string input) override { - return type_name + "(" + input + ")"; - } - - string GenerateAtomicOp(const string& op, const string& input_type_name, - const string& output_type_name, - const string& address, const string& input, const string& output, const string& memory_name) override { - if (op == "InterlockedAdd") { - if(input_type_name == "float") - { - return "atomicAdd_"+memory_name+"(" + address + ", " + input + ")"; - } - return "atomicAdd("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedAdd_Prev") { - if(input_type_name == "float") - { - return output_type_name + "(atomicAdd_"+memory_name+"(" + address + ", " + input +"))"; - } - return output_type_name + "(atomicAdd("+memory_name+"[" + address + "], uint(" + input + ")))"; - } else if (op == "InterlockedMin") { - return "atomicMin("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedMax") { - return "atomicMax("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedAnd") { - return 
"atomicAnd("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedOr") { - return "atomicOr("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedXor") { - return "atomicXor("+memory_name+"[" + address + "], uint(" + input + "))"; - } - else - { - throw runtime_error("Unsupported atomic operation: " + op); - } - } - -}; - -string GetGLSLHeader(Kernel* kernel) { - string header = R"( -#version 430 - -uint pcg(uint v) { - uint state = v * 747796405u + 2891336453u; - uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -float pcgf(uint v) { - return float(pcg(v)) / float(0xffffffffu); -} - -float asfloat(uint x) { - return uintBitsToFloat(x); -} - -uint asuint(float x) { - return floatBitsToUint(x); -} - -uint asuint(bool x) { - return uint(x); -} - -uint asuint(int x) { - return uint(x); -} - -uint asuint(uint x) { - return x; -} - -int asint(uint x) { - return int(x); -} - -bool asbool(uint x) { - return bool(x); -} - -)"; - kernel->var_names = vector(kernel->variables.size()); - kernel->var_types = vector(kernel->variables.size()); - header += "\nstruct UBO {\n"; - for (auto var : kernel->variables) { - kernel->var_names[var.second] = var.first->var_name; - kernel->var_types[var.second] = type_names[var.first->format.type]; - } - kernel->var_names.push_back("_kernel_block_offset"); - kernel->var_types.push_back(type_names[TFType::Uint]); - for (int i = 0; i < kernel->var_names.size(); i++) { - header += " " + kernel->var_types[i] + " " + kernel->var_names[i] + ";\n"; - } - header += "};\n\n"; - return header; -} - -string GLSLBufferDeclaration(const string& name, const string& type_name, const size_t binding) { - string decl = "layout(std430, binding = " + to_string(binding) + ") buffer buf_" + name + " {\n " + type_name + " " + name + "_mem[];\n};\n"; - //add atomic functions - decl += R"( -float atomicAdd_)" + name + R"(_mem(int index, float val) 
{ - uint uval = floatBitsToUint(val); - uint tmp0 = 0; - uint tmp1 = 0; - - while (true) { - tmp0 = atomicCompSwap()" + name + R"(_mem[index], tmp1, uval); - if (tmp1 == tmp0) break; - tmp1 = tmp0; - uval = floatBitsToUint(val + uintBitsToFloat(tmp1)); - } - - return uintBitsToFloat(tmp1); -} - -)"; - - return decl; -} - -string GLSLGroupBufferDeclaration(const string& name, const string& type_name, const size_t size) { - string decl = "shared " + type_name + " " + name + "[" + to_string(size) + "];\n"; - return decl; -} - -void GenerateGLSLKernel(Program* program, Kernel* kernel) { - kernel->generated_header_ = GetGLSLHeader(kernel); - - string buffers = GetBufferDeclarations(kernel, GLSLBufferDeclaration) + "\n"; - buffers += "layout(std140) uniform UBOBlock {\n UBO var;\n};\n\n"; - kernel->generated_bindings_ = buffers; - - string main_code = ""; - - main_code += GetGroupBufferDeclarations(kernel, GLSLGroupBufferDeclaration) + "\n"; - - vector group_size = kernel->root->group_size; - //pad with 1s - while (group_size.size() < 3) { - group_size.push_back(1); - } - - main_code += "layout (local_size_x = " + to_string(group_size[0]) + ", local_size_y = " + to_string(group_size[1]) + ", local_size_z = " + to_string(group_size[2]) + ") in;\n"; - - - main_code += R"( -void main() { - int block_id = int(gl_WorkGroupID.x + var._kernel_block_offset); - int block_thread_id0 = int(gl_LocalInvocationID.x); - int block_thread_id1 = int(gl_LocalInvocationID.y); - int block_thread_id2 = int(gl_LocalInvocationID.z); - -)"; - - GLSLGenerator generator = GLSLGenerator(program->ir_); - generator.GenerateKernelCode(kernel); - string kernel_code = generator.AssembleString(); - - main_code += AddIndent(kernel_code, " "); - - main_code += "}\n"; - - kernel->generated_main_ = main_code; - - kernel->full_generated_code_ = kernel->generated_header_ + kernel->generated_bindings_ + kernel->generated_main_; -} - -} // namespace TensorFrost \ No newline at end of file diff --git 
a/TensorFrost/Backend/CodeGen/Langs/HLSL.cpp b/TensorFrost/Backend/CodeGen/Langs/HLSL.cpp deleted file mode 100644 index b861b1ca..00000000 --- a/TensorFrost/Backend/CodeGen/Langs/HLSL.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { -using namespace std; - -class HLSLGenerator : public CodeGenerator { - public: - HLSLGenerator(IR* ir) : CodeGenerator(ir) { - name_map_ = { - {"var", "var."}, - {"group_barrier", "GroupMemoryBarrierWithGroupSync"} - }; - } - - string TypeCast(string type_name, string input) override { - return type_name + "(" + input + ")"; - } - - string GenerateAtomicOp(const string& op, const string& input_type_name, - const string& output_type_name, const string& address, - const string& input, const string& output, const string& memory_name) override - { - if (op == "InterlockedAdd") { - if(input_type_name == "float") - { - return "InterlockedAddF("+memory_name+", " + address + ", " + input + ")"; - } - return "InterlockedAdd("+memory_name+"[" + address + "], " + input + ")"; - } else if (op == "InterlockedAdd_Prev") { - if(input_type_name == "float") - { - return "InterlockedAddF("+memory_name+", " + address + ", " + input + ")"; - } - additional_lines.push_back("InterlockedAdd("+memory_name+"[" + address + "], " + - input + ", " + output + ");"); - return "0"; - } else { - return op + "("+memory_name+"[" + address + "], " + input + ")"; - } - } -}; - -string GetHLSLHeader(Kernel* kernel) { - string header =R"( -uint pcg(uint v) -{ - uint state = v * 747796405u + 2891336453u; - uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -float pcgf(uint v) -{ - return float(pcg(v)) / float(0xffffffffu); -} - -float InterlockedAddF(RWStructuredBuffer buffer, int index, float val) -{ - uint uval = asuint(val), tmp0 = 0, tmp1 = 0; - [allow_uav_condition] while (true) { - InterlockedCompareExchange(buffer[index], tmp0, uval, tmp1); - if (tmp1 == 
tmp0) break; - tmp0 = tmp1; - uval = asuint(val + asfloat(tmp1)); - } - return asfloat(tmp1); -} - -)"; - kernel->var_names = vector(kernel->variables.size()); - kernel->var_types = vector(kernel->variables.size()); - header += "\nstruct UBO {\n"; - for (auto var : kernel->variables) { - kernel->var_names[var.second] = var.first->var_name; - kernel->var_types[var.second] = type_names[var.first->format.type]; - } - kernel->var_names.push_back("_kernel_block_offset"); - kernel->var_types.push_back(type_names[TFType::Uint]); - for (int i = 0; i < kernel->var_names.size(); i++) { - header += " " + kernel->var_types[i] + " " + kernel->var_names[i] + ";\n"; - } - header += "};\n\n"; - return header; -} - -string HLSLBufferDeclaration(const string& name, const string& type_name, const size_t binding) { - return "RWStructuredBuffer<" + type_name + "> " + name + "_mem : register(u" + to_string(binding) + ");\n"; -} - -string HLSLGroupBufferDeclaration(const string& name, const string& type_name, const size_t size) { - string decl = "groupshared " + type_name + " " + name + "[" + to_string(size) + "];\n"; - return decl; -} - -void GenerateHLSLKernel(Program* program, Kernel* kernel) { - kernel->generated_header_ = GetHLSLHeader(kernel); - - kernel->generated_bindings_ = GetBufferDeclarations(kernel, HLSLBufferDeclaration) + "\n"; - kernel->generated_bindings_ += "cbuffer ubo : register(b0) { UBO var; }\n"; - - vector group_size = kernel->root->group_size; - // pad with 1s - while (group_size.size() < 3) { - group_size.push_back(1); - } - - string main_function = ""; - - main_function += GetGroupBufferDeclarations(kernel, HLSLGroupBufferDeclaration) + "\n"; - - main_function += "[numthreads(" + to_string(group_size[0]) + ", " + to_string(group_size[1]) + ", " + to_string(group_size[2]) + ")]"; - - main_function += R"( -void main(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID) -{ - int block_id = gid.x + var._kernel_block_offset; - int block_thread_id0 = gtid.x; - int 
block_thread_id1 = gtid.y; - int block_thread_id2 = gtid.z; - -)"; - - HLSLGenerator generator = HLSLGenerator(program->ir_); - generator.GenerateKernelCode(kernel); - string kernel_code = generator.AssembleString(); - - main_function += AddIndent(kernel_code, " "); - - main_function += "}\n"; - kernel->generated_main_ = main_function; - - kernel->full_generated_code_ = kernel->generated_header_ + kernel->generated_bindings_ + kernel->generated_main_; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Langs/Listing.cpp b/TensorFrost/Backend/CodeGen/Langs/Listing.cpp deleted file mode 100644 index 0704ed04..00000000 --- a/TensorFrost/Backend/CodeGen/Langs/Listing.cpp +++ /dev/null @@ -1,155 +0,0 @@ -#include "Backend/CodeGen/Generators.h" -#include "Compiler/KernelGen.h" - -namespace TensorFrost { -using namespace std; - -string GetNodeString(const Node* node, bool verbose) { - string listing = ""; - if (node->format.type != TFType::None) { - listing += DataTypeToString(node->format.type) + "(" + to_string(node->format.size) + ") "; - } - - if (node->format.type != TFType::None) { - // the tensor name - listing += node->var_name + "(" + to_string(node->format.size) + ") = "; - } - - listing += node->name + "("; - - const ArgumentManager& args = node->args; - - auto ArgTypePrint = [&](string name, ArgType type) { - if (args.Has(type)) { - string arr = name + "=["; - for (int i = 0; i < args.Count(type); i++) { - if (i != 0) arr += ","; - arr += GetNodeName(args.Get(type, i), false); - } - arr += "], "; - return arr; - } - return string(); - }; - - listing += ArgTypePrint("memory", ArgType::Memory); - listing += ArgTypePrint("inputs", ArgType::Input); - listing += ArgTypePrint("indices", ArgType::Index); - - if(node->args.OutputCount() > 0) { - listing += "outputs=["; - for(auto [in, out]: node->args.Outputs()) { - listing += GetNodeName(out, false) + ", "; - } - listing += "], "; - } - - if (!node->data.empty()) { - 
listing += "data=["; - for (int i = 0; i < node->data.size(); i++) { - if (i != 0) listing += ","; - listing += to_string(node->data[i]); - } - listing += "], "; - } - - if(node->flags.count() > 0) { - listing += "flags={"; - auto flags = node->flags.get_data(); - for(auto flad_data : flags) { - NodeProp flag = flad_data.first; - int64_t data = flad_data.second; - listing += NodeFlagsToString(flag); - if(data >= 0) { - listing += "(" + to_string(data) + ")"; - } - listing += ", "; - } - listing += "}, "; - } - - if (node->cost_ >= 0) { - listing += "cost=" + to_string(node->cost_) + ", "; - } - - if(node->memory_deps.size() > 0) { - listing += "memory_deps_count=" + to_string(node->memory_deps.size()) + ", "; - } - - if(node->indexing_mode_ != IndexingMode::Clamp) { - listing += "indexing_mode=" + IndexingModeToString(node->indexing_mode_) + ", "; - } - - listing += ArgTypePrint("shape", ArgType::Shape); - -#ifdef _DEBUG - if (verbose) { - listing += "index=" + to_string(node->index_) + ", "; - listing += "debug_index=" + to_string(node->debug_index) + ", "; - listing += "debug_name=" + node->debug_name + ", "; - listing += "created_in=" + node->created_in_pass + ", "; - listing += "created_in_func=" + node->created_in_function + ", "; - } -#endif - - listing += ")"; - - return listing; -} - -string GetOperationListing(const IR& ir, bool compact, map debug) { - // first give unique names to all the tensors - GenerateNodeNames(ir); - //ClusterProp clusters = ir.GetClusterProperties(); - - // now create the listing - int prev_depth = 0; - string listing; - for (auto node = ir.begin(); !node.end(); node.next()) { - if (compact) { - if (node->name == "const") continue; - } - - if (debug.contains(node.get())) { - listing += "[DEBUG] " + debug[node.get()] + ": \n"; - } - - // indent - int depth = node.depth() - 1; - //add scope brackets - if (depth < prev_depth) { - for (int i = prev_depth - 1; i >= depth; i--) { - for (int j = 0; j < i; j++) { - listing += " "; - } - 
listing += "}\n"; - } - } - else if (depth > prev_depth) { - for (int i = prev_depth; i < depth; i++) { - for (int j = 0; j < i; j++) { - listing += " "; - } - listing += "{\n"; - } - } - for (int i = 0; i < depth; i++) { - listing += " "; - } - prev_depth = depth; - - listing += GetNodeString(*node, true); - listing += "\n"; - } - - for (int i = prev_depth - 1; i >= 0; i--) { - for (int j = 0; j < i; j++) { - listing += " "; - } - listing += "}\n"; - } - - return listing; -} - -} // namespace TensorFrost diff --git a/TensorFrost/Backend/KernelManager.cpp b/TensorFrost/Backend/KernelManager.cpp deleted file mode 100644 index bf06a59b..00000000 --- a/TensorFrost/Backend/KernelManager.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include "KernelManager.h" - -namespace TensorFrost { -void KernelManager::AddKernelID(Program *program, Kernel *kernel) { - programs.insert(program); - kernel->kernel_id_ = global_kernel_id++; - kernel_map[kernel->kernel_id_] = kernel; -} - -vector KernelManager::GetAllMainFunctions() { - vector main_functions; - for (auto& program : programs) { - main_functions.push_back(program->main_function_); - } - return main_functions; -} - -vector, vector>>> KernelManager::GetAllKernels() { - vector, vector>>> kernels; - kernels.resize(kernel_map.size()); - for (auto& kernel : kernel_map) { - vector> args; - map memory_bindings = kernel.second->GetMemoryBindings(); - args.resize(memory_bindings.size()); - for (auto& [mem_node, binding] : memory_bindings) { - string name = mem_node->var_name + "_mem"; - string type_name = "uint"; - args[binding] = {name, type_name}; - } - tuple code = {kernel.second->generated_header_, kernel.second->generated_bindings_, kernel.second->generated_main_}; - kernels[kernel.first] = {code, args}; - } - return kernels; -} - -KernelManager* global_kernel_manager = nullptr; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/KernelManager.h b/TensorFrost/Backend/KernelManager.h deleted file mode 
100644 index 172788c8..00000000 --- a/TensorFrost/Backend/KernelManager.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "Compiler/KernelGen.h" -#include "TensorMemory.h" - -namespace TensorFrost { - -class KernelManager -{ - unordered_set programs; - unordered_map kernel_map; - size_t global_kernel_id = 0; - public: - - KernelManager() = default; - virtual void DispatchKernel(TFDispatchInfo info) = 0; - void AddKernelID(Program* program, Kernel* kernel); - vector GetAllMainFunctions(); - - vector, vector>>> GetAllKernels(); - Kernel* GetKernel(size_t kernel_id) { return kernel_map[kernel_id]; } -}; - -extern KernelManager* global_kernel_manager; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/RenderDoc.cpp b/TensorFrost/Backend/RenderDoc.cpp deleted file mode 100644 index cd26b747..00000000 --- a/TensorFrost/Backend/RenderDoc.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "RenderDoc.h" -#include -#include -#if defined(_WIN32) -#include -#else -#include -#include -#include -#endif - -namespace TensorFrost { - RENDERDOC_API_1_4_2* RDCAPI = nullptr; - - void LoadRDCAPI() - { - if (RDCAPI) - return; - -#if defined(_WIN32) - if (HMODULE mod = GetModuleHandleA("renderdoc.dll")) - { - std::cout << "renderdoc.dll successfully loaded" << std::endl; - pRENDERDOC_GetAPI RENDERDOC_GetAPI = (pRENDERDOC_GetAPI)GetProcAddress(mod, "RENDERDOC_GetAPI"); - RENDERDOC_GetAPI(eRENDERDOC_API_Version_1_4_2, (void**)&RDCAPI); - } -#else - if (void* mod = dlopen("librenderdoc.so", RTLD_NOW | RTLD_NOLOAD)) - { - std::cout << "librenderdoc.so successfully loaded" << std::endl; - pRENDERDOC_GetAPI RENDERDOC_GetAPI = (pRENDERDOC_GetAPI)dlsym(mod, "RENDERDOC_GetAPI"); - RENDERDOC_GetAPI(eRENDERDOC_API_Version_1_4_2, (void**)&RDCAPI); - } -#endif - } - - void StartRenderDocCapture() - { - LoadRDCAPI(); - - if (RDCAPI) - { - std::cout << "RenderDoc capture started" << 
std::endl; - RDCAPI->StartFrameCapture(NULL, NULL); - } - } - - void EndRenderDocCapture() - { - if (RDCAPI) - { - std::cout << "RenderDoc capture ended" << std::endl; - RDCAPI->EndFrameCapture(NULL, NULL); - } - } -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/RenderDoc.h b/TensorFrost/Backend/RenderDoc.h deleted file mode 100644 index 9a9f2f2d..00000000 --- a/TensorFrost/Backend/RenderDoc.h +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once - -namespace TensorFrost { - -void StartRenderDocCapture(); -void EndRenderDocCapture(); - -} // namespace TensorFrost diff --git a/TensorFrost/Backend/TensorMemory.cpp b/TensorFrost/Backend/TensorMemory.cpp deleted file mode 100644 index 9f439610..00000000 --- a/TensorFrost/Backend/TensorMemory.cpp +++ /dev/null @@ -1,236 +0,0 @@ -#include "TensorMemory.h" - -namespace TensorFrost { - -size_t GetLinearSize(const vector& shape) { - size_t size = 1; - for (size_t dim : shape) { - size *= dim; - } - return size; -} - -vector GetShape(const TFTensor *tensor) { - vector shape; - for (size_t i = 0; i < tensor->dim; i++) { - shape.push_back(tensor->shape[i]); - } - return shape; -} - -size_t GetSize(const TFTensor *tensor) { - size_t size = 1; - for (size_t i = 0; i < tensor->dim; i++) { - size *= tensor->shape[i]; - } - return size; -} - -TFBuffer * TensorMemoryManager::AllocateBuffer(size_t size) { - TFBuffer* buffer = CreateBuffer(size); - if(allocation_history.contains(size)) { - size_t old_delay = GetDeallocationDelay(size); - allocation_delay[size] = std::min(std::max(old_delay, tick - allocation_history[size]), MAX_POSSIBLE_UNUSED_TIME); - } else { - allocation_delay[size] = DEFAULT_MAX_UNUSED_TIME; - } - buffers_created++; - //add the buffer to the list of allocated buffers - allocated_buffers[size].insert(buffer); - allocation_history[size] = tick; - return buffer; -} - -TFTensor * TensorMemoryManager::AllocateTensor(const vector &shape, const TFDataFormat type, const char* name) { - 
size_t size = GetLinearSize(shape); - - if (size == 0) { - throw invalid_argument("Trying to allocate a tensor with size 0"); - } - - TFBuffer* buf = TryAllocateBuffer(size); - buf->read_only = false; - ((TFBufferTemplate*)buf)->UpdateName(name); - return MakeTensor(shape, buf, type); -} - -TFTensor * TensorMemoryManager::AllocateTensorWithData(const vector &shape, const vector &data, - const TFDataFormat type, bool read_only, const char* name) { - TFTensor* tensor_memory = AllocateTensor(shape, type, name); - tensor_memory->buffer->read_only = read_only; - ((TFBufferTemplate*)tensor_memory->buffer)->SetDataAtOffset(0, data); - return tensor_memory; -} - -void TensorMemoryManager::DeallocateTensor(TFTensor tensor) { - DeallocateBuffer(tensor.buffer); -} - -TFTensor * TensorMemoryManager::MakeTensor(size_t *shape, size_t dim, TFBuffer *buf, TFDataFormat type) { - TFTensor* tensor = new TFTensor(); - tensor->buffer = buf; - tensor->dim = dim; - tensor->shape = shape; - tensor->format = type; - return tensor; -} - -TFTensor * TensorMemoryManager::MakeTensor(const vector &shape, TFBuffer *buf, TFDataFormat type) { - size_t* shape_arr = new size_t[shape.size()]; - std::copy(shape.begin(), shape.end(), shape_arr); - return MakeTensor(shape_arr, shape.size(), buf, type); -} - -size_t TensorMemoryManager::GetAllocatedSize() const { - size_t total = 0; - for(auto& [size, buffers]: allocated_buffers) { - total += size * buffers.size(); - } - return total; -} - -size_t TensorMemoryManager::GetUnusedAllocatedSize() const { - size_t total = 0; - for(auto& [size, buffers]: allocated_buffers) { - for(auto& buffer: buffers) { - if(unused_buffers.contains(buffer)) { - total += size; - } - } - } - return total; -} - -void TensorMemoryManager::DeallocateBuffer(TFBuffer *buffer) { - unused_buffers.insert(buffer); - buffer->time_since_used = 0; - buffer->used_size = 0; - buffer->up_to_date = false; - buffer->name = "none"; -} - -void TensorMemoryManager::RemoveBuffer(TFBuffer *buffer) 
{ - size_t size = buffer->size; - allocated_buffers[size].erase(buffer); - unused_buffers.erase(buffer); - DeleteBuffer(buffer); - buffers_removed++; -} - -//#define READBACK_DEBUG - -vector TensorMemoryManager::Readback(const TFTensor *memory) { - vector data(GetSize(memory)); -#ifdef READBACK_DEBUG - cout << "Reading back " << data.size() << " elements from buffer of size " << memory->buffer->size << endl; -#endif - ((TFBufferTemplate*)memory->buffer)->GetDataAtOffset(0, data.size(), data.data()); - return data; -} - -uint TensorMemoryManager::ReadbackValue(const TFTensor *memory, size_t index) { - uint32_t data; - ((TFBufferTemplate*)memory->buffer)->GetDataAtOffset(index, 1, &data); - return data; -} - -void TensorMemoryManager::Writeback(const TFTensor *memory, const vector &data) { - ((TFBufferTemplate*)memory->buffer)->SetDataAtOffset(0, data); -} - -void TensorMemoryManager::WritebackValue(const TFTensor *memory, size_t index, uint32_t value) { - ((TFBufferTemplate*)memory->buffer)->SetDataAtOffset(index, {value}); -} - -void TensorMemoryManager::UpdateTick() { - unordered_set buffers_to_delete; - - for(auto& buffer: unused_buffers) { - size_t buf_size = buffer->size; - if(buffer->time_since_used > GetDeallocationDelay(buf_size)) { - buffers_to_delete.insert(buffer); - } else { - buffer->time_since_used++; - } - } - - //delete all buffers that are marked for deletion - for(auto& buffer: buffers_to_delete) { - RemoveBuffer(buffer); - } - tick++; - -#ifdef DEBUG_DYNAMIC_ALLOCATION - if(tick%2048 == 0) { - if(buffers_created> 0 || buffers_removed > 0) { - cout << "Note: " << buffers_created << " buffers created and " << buffers_removed << " buffers removed in the last 2048 ticks" << endl; - } - buffers_created = 0; - buffers_removed = 0; - } -#endif -} - -size_t TensorMemoryManager::GetDeallocationDelay(size_t buf_size) const { - if(allocation_delay.contains(buf_size)) { - return allocation_delay.at(buf_size); - } - - return DEFAULT_MAX_UNUSED_TIME; -} - 
-TFBuffer *TensorMemoryManager::TryAllocateBuffer(size_t size) { - //try to find a non-used buffer of the correct size - TFBuffer* buffer = nullptr; - bool found = false; - //find the smallest buffer that is larger than the requested size - size_t min_size = size; - size_t max_size = 16 * size; - //get iterator to the first buffer that is larger than the requested size - auto it = allocated_buffers.lower_bound(min_size); - //if no buffer is larger than the requested size, get the first buffer - if(it == allocated_buffers.end()) { - it = allocated_buffers.begin(); - } - //iterate through the buffers - for(; it != allocated_buffers.end(); it++) { - if(it->first > max_size) { - break; - } - if(it->first < size) { - continue; - } - for(auto buf: it->second) { - if(!unused_buffers.contains(buf)) { - continue; - } - buffer = buf; - found = true; - } - if(found) { - break; - } - } - //if no buffer was found, create a new one - if(!found) { - buffer = AllocateBuffer(size); - } else { - unused_buffers.erase(buffer); - } - buffer->used_size = size; - buffer->time_since_used = 0; - UpdateTick(); - return buffer; -} - -TensorMemoryManager::~TensorMemoryManager() { - for(auto& [size, buffers]: allocated_buffers) { - for(auto& buffer: buffers) { - DeleteBuffer(buffer); - } - } -} - -TensorMemoryManager* global_memory_manager = nullptr; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/TensorMemory.h b/TensorFrost/Backend/TensorMemory.h deleted file mode 100644 index 546d3b37..00000000 --- a/TensorFrost/Backend/TensorMemory.h +++ /dev/null @@ -1,146 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "../Tensor/Tensor.h" - -//#define DEBUG_DYNAMIC_ALLOCATION - -namespace TensorFrost { - -using namespace std; - -extern "C" { - struct TFBuffer { - size_t size = 0; - size_t used_size = 0; - size_t time_since_used = 0; - bool up_to_date = false; - bool read_only = false; - const char* name 
= nullptr; - //add type descriptor (for special kinds of buffers) - }; - - struct TFTensor { - TFBuffer* buffer; - TFDataFormat format; - size_t dim; - const size_t* shape; - }; - - struct TFTensorList { - size_t count; - const TFTensor* tensors; - }; - - struct TFDispatchInfo { - size_t kernel_id; - size_t read_write_count; - const TFTensor* read_write_tensors; - size_t read_only_count; - const TFTensor* read_only_tensors; - size_t variable_count; - const uint32_t* variables; - size_t work_group_count; - }; - - typedef TFTensor alloc_func(const char*, const size_t*, size_t, TFDataFormat, void*); - typedef void dealloc_func(TFTensor, void*); - typedef uint readback_func(TFTensor, size_t, void*); - typedef void writeback_func(TFTensor, size_t, uint32_t, void*); - typedef void dispatch_func(TFDispatchInfo, void*); - typedef void region_func(const char*, bool, void*); - - struct TFRuntime { - alloc_func* alloc; - dealloc_func* dealloc; - readback_func* readback; - writeback_func* writeback; - dispatch_func* dispatch; - region_func* region; - void* custom_data; - }; - - typedef void cpu_dispatch_func(const uint32_t* var, uint32_t** mem, uint work_group_count); - typedef void main_func(TFTensor*, TFTensor*, TFRuntime); -} - -class TFBufferTemplate : public -TFBuffer { -public: - TFBufferTemplate(size_t size) : TFBuffer{ size } {} - - virtual void UpdateName(const char* name) { - throw std::runtime_error("UpdateName not implemented"); - } - virtual void SetDataAtOffset(size_t offset, const vector& data) { - throw std::runtime_error("SetDataAtOffset not implemented"); - } - virtual void GetDataAtOffset(size_t offset, size_t size, uint32_t* data) { - throw std::runtime_error("GetDataAtOffset not implemented"); - } -}; - -using uint = unsigned int; - -size_t GetLinearSize(const vector& shape); -vector GetShape(const TFTensor* tensor); -size_t GetSize(const TFTensor* tensor); - -class TensorMemoryManager { -private: - size_t tick = 0; - size_t buffers_created = 0; - size_t 
buffers_removed = 0; - const size_t DEFAULT_MAX_UNUSED_TIME = 128; - const size_t MAX_POSSIBLE_UNUSED_TIME = 32768; - map> allocated_buffers; - map allocation_history; //stores the last tick when a buffer of a certain size was allocated - map allocation_delay; //stores the time between the last 2 allocations of a buffer size - unordered_set unused_buffers; - - static TFTensor* MakeTensor(size_t* shape, size_t dim, TFBuffer* buf, TFDataFormat type); - static TFTensor* MakeTensor(const vector& shape, TFBuffer* buf, TFDataFormat type); - void UpdateTick(); - size_t GetDeallocationDelay(size_t size) const; - - TFBuffer* AllocateBuffer(size_t size); - TFBuffer* TryAllocateBuffer(size_t size); - void DeallocateBuffer(TFBuffer* buffer); - void RemoveBuffer(TFBuffer* buffer); - -protected: - virtual TFBuffer* CreateBuffer(size_t size) { - throw std::runtime_error("CreateBuffer not implemented"); - } - - virtual void DeleteBuffer(TFBuffer * buffer) { - throw std::runtime_error("DeleteBuffer not implemented"); - } - -public: - virtual vector Readback(const TFTensor* memory); - virtual uint ReadbackValue(const TFTensor* memory, size_t index); - virtual void Writeback(const TFTensor* memory, const vector& data); - virtual void WritebackValue(const TFTensor* memory, size_t index, uint32_t value); - - TFTensor* AllocateTensor(const vector& shape, const TFDataFormat type = TFTypeFloat32, const char* name = nullptr); - TFTensor* AllocateTensorWithData(const vector& shape, const vector& data, const TFDataFormat type = TFTypeFloat32, bool read_only = false, const char* name = nullptr); - void DeallocateTensor(TFTensor tensor); - - size_t GetAllocatedSize() const; - size_t GetUnusedAllocatedSize() const; - - ~TensorMemoryManager(); -}; - - -extern TensorMemoryManager* global_memory_manager; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/CMakeLists.txt b/TensorFrost/CMakeLists.txt index db1ea482..1fa22052 100644 --- a/TensorFrost/CMakeLists.txt +++ 
b/TensorFrost/CMakeLists.txt @@ -1,8 +1,29 @@ -file(GLOB_RECURSE TENSORFROST_SOURCE_LIST CONFIGURE_DEPENDS *.cpp) -file(GLOB_RECURSE TENSORFROST_HEADER_LIST CONFIGURE_DEPENDS *.h *.hpp) +set(TF_INC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) +set(TF_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) -pybind11_add_module(TensorFrost ${TENSORFROST_SOURCE_LIST} ${TENSORFROST_HEADER_LIST}) +file(GLOB_RECURSE TENSORFROST_SOURCE_LIST CONFIGURE_DEPENDS + ${TF_SRC_DIR}/*.cpp) +file(GLOB_RECURSE TENSORFROST_HEADER_LIST CONFIGURE_DEPENDS + ${TF_INC_DIR}/*.h ${TF_INC_DIR}/*.hpp) + +set(PYBIND_MODULE ${CMAKE_CURRENT_SOURCE_DIR}/PybindModule.cpp) + +# ---- Build pybind11 module ---- +pybind11_add_module(TensorFrost + ${PYBIND_MODULE} + ${TENSORFROST_SOURCE_LIST} + ${TENSORFROST_HEADER_LIST} +) + +# ---- Include directories ---- +target_include_directories(TensorFrost + PRIVATE + ${TF_INC_DIR} + ${Python3_INCLUDE_DIRS} +) + +# ---- Fix for macOS RPATH issues ---- if(APPLE) set(CMAKE_MACOSX_RPATH ON) set_target_properties(TensorFrost PROPERTIES @@ -11,49 +32,47 @@ if(APPLE) ) endif() -# Add GLFW +# ---- libraries ---- target_link_libraries(TensorFrost PRIVATE glfw) -glad_add_library(glad_gl_core_46 SHARED API gl:core=4.6) +glad_add_library(glad_gl_core_46 SHARED API gl:core=4.6) target_link_libraries(TensorFrost PRIVATE glad_gl_core_46) glad_add_library(glad_vulkan_12 REPRODUCIBLE LOADER API vulkan=1.2) target_link_libraries(TensorFrost PRIVATE glad_vulkan_12) -target_include_directories(TensorFrost PRIVATE ${Python3_INCLUDE_DIRS}) - -target_include_directories(TensorFrost PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - -#add imgui headers -target_include_directories(TensorFrost PRIVATE ${CMAKE_SOURCE_DIR}/external/imgui) -target_include_directories(TensorFrost PRIVATE ${CMAKE_SOURCE_DIR}/external/imgui/backends) - -#add imgui sources -file(GLOB IMGUI_SOURCE_LIST ${CMAKE_SOURCE_DIR}/external/imgui/*.cpp) -file(GLOB IMGUI_BACKEND_SOURCE_LIST 
${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_glfw.cpp ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_opengl3.cpp) - +# ---- ImGui ---- +target_include_directories(TensorFrost PRIVATE + ${CMAKE_SOURCE_DIR}/external/imgui + ${CMAKE_SOURCE_DIR}/external/imgui/backends +) +file(GLOB IMGUI_SOURCE_LIST + ${CMAKE_SOURCE_DIR}/external/imgui/*.cpp) +file(GLOB IMGUI_BACKEND_SOURCE_LIST + ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_glfw.cpp + ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_opengl3.cpp) target_sources(TensorFrost PRIVATE ${IMGUI_SOURCE_LIST} ${IMGUI_BACKEND_SOURCE_LIST}) -#add renderdoc headers +# ---- RenderDoc headers ---- target_include_directories(TensorFrost PRIVATE ${CMAKE_SOURCE_DIR}/external/renderdoc) +# ---- convenience / IDE niceties ---- +source_group(TREE ${TF_SRC_DIR} PREFIX "Source Files" + FILES ${TENSORFROST_SOURCE_LIST}) -add_custom_target(install_python_package ALL - DEPENDS TensorFrost -) +source_group(TREE ${CMAKE_CURRENT_SOURCE_DIR} PREFIX "Source Files" + FILES ${PYBIND_MODULE}) -set(DEBUG_PYTHON_SCRIPT "${CMAKE_SOURCE_DIR}/examples/debug.py") +source_group(TREE ${TF_INC_DIR} PREFIX "Header Files" + FILES ${TENSORFROST_HEADER_LIST}) -set_target_properties(TensorFrost PROPERTIES - VS_DEBUGGER_COMMAND "${Python3_EXECUTABLE}" - VS_DEBUGGER_COMMAND_ARGUMENTS "${DEBUG_PYTHON_SCRIPT}" - VS_DEBUGGER_WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" -) +# ---- VS/debug helpers, profiling flags, etc. 
(unchanged) ---- +add_custom_target(install_python_package ALL DEPENDS TensorFrost) -# Set /PROFILE for RELWITHDEBINFO +set(DEBUG_PYTHON_SCRIPT "${CMAKE_SOURCE_DIR}/examples/debug.py") set_target_properties(TensorFrost PROPERTIES - LINK_FLAGS_RELWITHDEBINFO "/PROFILE" -) - -source_group(TREE ${CMAKE_CURRENT_SOURCE_DIR} PREFIX "Source Files" FILES ${TENSORFROST_SOURCE_LIST} ${TENSORFROST_HEADER_LIST}) - + VS_DEBUGGER_COMMAND "${Python3_EXECUTABLE}" + VS_DEBUGGER_COMMAND_ARGUMENTS "${DEBUG_PYTHON_SCRIPT}" + VS_DEBUGGER_WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + LINK_FLAGS_RELWITHDEBINFO "/PROFILE" +) \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Arguments.cpp b/TensorFrost/Compiler/Graph/Arguments.cpp deleted file mode 100644 index d9c91b57..00000000 --- a/TensorFrost/Compiler/Graph/Arguments.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "Arguments.h" - -namespace TensorFrost { - -void ArgumentManager::ClearOutputs() { - outputs_.clear(); -} - -const map arg_type_names = { - {ArgType::Input, "Input"}, {ArgType::Index, "Index"}, {ArgType::Shape, "Shape"}, - {ArgType::Memory, "Memory"}, {ArgType::None, "None"}, -}; - -string TypeToString(ArgType type) { - return arg_type_names.at(type); -} - - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Arguments.h b/TensorFrost/Compiler/Graph/Arguments.h deleted file mode 100644 index b7bf5698..00000000 --- a/TensorFrost/Compiler/Graph/Arguments.h +++ /dev/null @@ -1,221 +0,0 @@ -#pragma once - -#include "Compiler/Operations.h" -#include "Utility/Utility.h" -#include - -namespace TensorFrost { - -enum class ArgType { - Input, - Index, - Shape, - Memory, - None, - Count, -}; - -class Tensor; -class Node; -string TypeToString(ArgType type); - -//argument type and index -using ArgID = pair; -//input nodes with argument type and index -using NodeArguments = map; -//argument type and input node -using Arg = pair; -//argument type and input/output node - edge of the 
graph -using ArgEdge = pair; - -struct HashArgID { - size_t operator()(const ArgID& id) const { - return (size_t)id.first + id.second * (size_t)ArgType::Count; - } -}; - -//set of edges -using ArgEdges = set; - -#define MAX_ARGS_PER_TYPE 8 -#define MAX_ARGS ((int)ArgType::Count * MAX_ARGS_PER_TYPE) - -class ArgumentManager { - Node* node_; - bool add_parenthesis = false; - unordered_map argument_types_; - unordered_map argument_counts_; - unordered_map argument_names_; - unordered_map argument_requires_parenthesis_; - unordered_map inputs_; - ArgEdges outputs_; - - void AddOutput(ArgID id, Node* node) { - outputs_.insert({{id, node_}, node}); - } - - void RemoveOutput(ArgID id, Node* node) { - if (!outputs_.contains({{id, node_}, node})) { - throw std::runtime_error("Output does not exist"); - } - outputs_.erase({{id, node_}, node}); - } - - void UpdateOutputs(); - void ClearOutputs(); -public: - ArgumentManager(Node* node) { - if (node == nullptr) { - throw std::runtime_error("Node is null"); - } - this->node_ = node; - } - - void AddParenthesis(bool add) { - add_parenthesis = add; - } - - const unordered_map& Inputs() const { - return inputs_; - } - - const ArgEdges& Outputs() const { - return outputs_; - } - - unordered_map InputsCopy() const { - return inputs_; - } - - ArgEdges OutputsCopy() const { - return outputs_; - } - - size_t OutputCount() const { - return outputs_.size(); - } - - void AddArgument(ArgID id, Node *node); - void AddArgument(ArgType type, int index, Node *node) { - AddArgument(ArgID(type, index), node); - } - - void UpdateArgument(ArgID id, Node *node); - - void AddArguments(NodeArguments new_args) { - for (auto& [id, node] : new_args) { - AddArgument(id, node); - } - } - - void SetName(ArgID id, string name, bool requires_parenthesis = false) { - argument_names_[id] = name; - argument_requires_parenthesis_[id] = requires_parenthesis; - } - - bool Has(ArgID id) const { - return inputs_.find(id) != inputs_.end(); - } - - bool Has(ArgType 
type, int index = 0) const { - ArgID id = ArgID(type, index); - return Has(id); - } - - Node* Get(ArgType type, int index = 0) const { - ArgID id = ArgID(type, index); - auto Arg = inputs_.find(id); - if (Arg != inputs_.end()) { - return Arg->second; - } else { - throw std::runtime_error("Argument of type " + TypeToString(type) + " at index " + std::to_string(index) + " not found"); - } - } - - void Remove(ArgID id); - void Remove(ArgType type, int index = 0) { - Remove(ArgID(type, index)); - } - - const Tensor *GetTensor(ArgType type, int index = 0) const; - - const Tensor& operator[](int index) const; - - TFDataFormat Format(ArgType type, int index = 0) const { - ArgID id = ArgID(type, index); - auto Arg = argument_types_.find(id); - if (Arg != argument_types_.end()) { - return Arg->second; - } - else { - throw std::runtime_error("Argument format not found"); - } - } - - int Count(ArgType type) const { - auto Arg = argument_counts_.find(type); - if (Arg != argument_counts_.end()) { - return Arg->second; - } - else { - return 0; - } - } - - bool RequiresParenthesis(ArgID id) const { - auto Arg = argument_requires_parenthesis_.find(id); - if (Arg != argument_requires_parenthesis_.end()) { - return Arg->second; - } - else { - return false; - } - } - - string Name(ArgType type, int index = 0) const { - ArgID id = ArgID(type, index); - auto Arg = argument_names_.find(id); - if (Arg != argument_names_.end()) { - string name = Arg->second; - if (add_parenthesis && RequiresParenthesis(id)) { - name = "(" + name + ")"; - } - return name; - } - else { - throw std::runtime_error("Argument name not found"); - } - } - - NodeArguments GetArguments() const { - NodeArguments arguments; - for (auto& [id, node] : inputs_) { - arguments[id] = node; - } - return arguments; - } - - NodeArguments GetArguments(ArgType type) const { - NodeArguments arguments; - for (auto& [id, node] : inputs_) { - if (id.first == type) { - arguments[id] = node; - } - } - return arguments; - } - - map 
GetTensors(ArgType type) const; - - vector GetTensorVector(ArgType type) const; - - ~ArgumentManager(); - - bool CannotMoveArgument(ArgID id); - bool CannotCopyArgument(ArgID id); - bool IsChangingInput(ArgID arg); - - void RemoveArguments(ArgType arg); -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/IR.cpp b/TensorFrost/Compiler/Graph/IR.cpp deleted file mode 100644 index 5eb497c1..00000000 --- a/TensorFrost/Compiler/Graph/IR.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "IR.h" - -namespace TensorFrost { - -int IR::max_kernel_memory_dependencies = 12; -int IR::max_allowed_memory_dependencies = 12; - -void IR::RemoveNode(Node* node) { - if (node->valid()) { - //remove itself from all outputs - auto outputs = node->args.Outputs(); - for (auto& [id, output_node] : outputs) { - output_node->args.Remove(id.first); - } - - //remove all inputs - auto inputs = node->args.InputsCopy(); - for (auto& [id, input_node] : inputs) { - node->args.Remove(id); - } - - // if its the child of its parent, update the parent's child - if (node->parent && node->parent->child == node) { - node->parent->child = node->next; - } - - // if child node exists, iterate through it and remove all children - if (node->child) { - vector to_delete; - for (auto child = NodeIterator(node); !child.end(); child.next()) { - to_delete.push_back(*child); - } - for (int i = (int)to_delete.size() - 1; i >= 0; i--) { - RemoveNode(to_delete[i]); - } - } - - // if direct child of its parent - if (node->parent && node->parent->child == node) { - node->parent->child = node->next; - } else if (node->prev) { - node->prev->next = node->next; - } - - node->next->prev = node->prev; - existing_nodes.erase(node); - delete node; - } -} - -#ifdef _RELWITHDEBINFO -//#define PROFILE_COMPILATION -#endif - -void IR::RunCompilationPass(string pass_name, const function& expression, bool print, bool update_graph) { - current_pass = pass_name; -#ifdef PROFILE_COMPILATION - auto 
start = std::chrono::high_resolution_clock::now(); -#endif - - try { - expression(); - } catch (const std::exception& e) { - CheckIR(pass_name, false, false); - throw std::runtime_error("Error in compilation pass " + pass_name + ": " + e.what()); - } - -#ifdef PROFILE_COMPILATION - auto end = std::chrono::high_resolution_clock::now(); - float duration = (float) std::chrono::duration_cast(end - start).count() / 1000.0f; - - PassStats stats; - stats.pass_name = pass_name; - stats.duration = duration; - stats.node_count = 0; - for (auto node = begin(); !node.end(); node.next()) { - stats.node_count++; - } - pass_stats.push_back(stats); -#endif - - if (update_graph) { - UpdateGraph(); - } - - if (print) { - CheckIR(pass_name, false, false); - } - current_pass = ""; -} - -#define MAX_COMPILATION_ITERATIONS 32 -#define LOAD_FUSION - -bool IR::RunIterativeCompilationPass(string pass_name, int max_iterations, const function& expression, bool print, bool update_graph) { - bool anything_happened = false; - RunCompilationPass(pass_name, [&]() { - bool converged = false; - for (int i = 0; i < max_iterations; i++) { - bool finished = expression(); - if (finished) { - converged = true; - break; - } else { - anything_happened = true; - } - } - if (!converged) { - throw std::runtime_error("Failed to converge on " + pass_name + ", exceeded maximum iterations"); - } - }, print, update_graph); - return anything_happened; -} - -void IR::CompileIR() -{ - // TODO (Moroz): Add auto tests into build system - CheckIR("Input", false, false); - RunCompilationPass("InitialCompilation", [&]() { - RunCompilationPass("GetInputList", [&]() { GetInputList(); }); - RunCompilationPass("OptimizeOperations", [&]() { OptimizeOperations(); }); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }, true); - RunCompilationPass("UnrollLoops", [&]() { UnrollLoops(); }, true); - RunCompilationPass("TryReplaceModificationsWithVersions", [&]() { 
TryReplaceModificationsWithVersions(); }, true); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }, true); - - RunIterativeCompilationPass("InsertAlgorithmicPrimitives_PreAutodiff", MAX_COMPILATION_ITERATIONS, [&]() { - return InsertAlgorithmicPrimitives(true); - }, true); - - RunIterativeCompilationPass("ComputeAutodiff", MAX_COMPILATION_ITERATIONS, [&]() { - return ComputeAutodiff(); - }); - - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }, true); - RunCompilationPass("UnrollAtomicOperations", [&]() { UnrollAtomicOperations(); }); - RunCompilationPass("OptimizeReductions", [&]() { OptimizeReductions(); }, true); - - RunIterativeCompilationPass("InsertAlgorithmicPrimitives_PostAutodiff", MAX_COMPILATION_ITERATIONS, [&]() { - return InsertAlgorithmicPrimitives(false); - }, true); - - RunCompilationPass("TryReplaceModificationsWithVersions", [&]() { TryReplaceModificationsWithVersions(); }); - RunCompilationPass("OptimizeOperations", [&]() { OptimizeOperations(); }); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }, true); - }, true); - - RunCompilationPass("KernelGeneration", [&]() { - RunCompilationPass("SeparateOperationsIntoKernels", [&]() { SeparateOperationsIntoKernels(); }, true); - RunCompilationPass("CheckKernelShapes", [&]() { CheckKernelShapes(); }); - - RunCompilationPass("ReorderOperations", [&]() { ReorderOperations(); }); - RunCompilationPass("MoveShapeOutsideKernels", [&]() { MoveShapeOutsideKernels(); }); - RunCompilationPass("OptimizeKernels", [&]() { OptimizeKernels(); }); - RunCompilationPass("OptimizeHost", [&]() { OptimizeHost(); }); - - RunCompilationPass("UnrollLoops", [&]() { UnrollLoops(4); }); - RunCompilationPass("TryReplaceModificationsWithVersions", [&]() { TryReplaceModificationsWithVersions(); }, true); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }); - RunCompilationPass("CheckKernelShapes", [&]() { 
CheckKernelShapes(); }); - - RunCompilationPass("UpdateKernelShapes", [&]() { UpdateKernelShapes(); }, true); - }, true); - -#ifdef LOAD_FUSION - RunIterativeCompilationPass("Iterative load fusion", MAX_COMPILATION_ITERATIONS, [&]() { - AddKernelGlobalLoadOperations(); - AddMemoryOpIndices(); - bool no_changes = OptimizeKernelLoadOperations(); - if(!no_changes) { - RemoveUnusedOperations(); - } - return no_changes; - }); -#endif - - RunCompilationPass("Reclusterize", [&]() { - LimitKernelMemoryDependencies(); - }, true); - - RunCompilationPass("FinilizeKernels", [&]() { - RunCompilationPass("AddKernelGlobalStoreOperations", [&]() { AddKernelGlobalStoreOperations(); }); - RunCompilationPass("AddKernelGlobalStoreOperations: RemoveUnusedKernels", [&]() { RemoveUnusedKernels(); }, true); - RunCompilationPass("AddMemoryOpIndices", [&]() { AddMemoryOpIndices(); }); - RunCompilationPass("ReorderOperations", [&]() { ReorderOperations(); }); - RunCompilationPass("OptimizeOperations", [&]() { OptimizeOperations(); }); - RunCompilationPass("AddMemoryOpIndices", [&]() { AddMemoryOpIndices(); }, true); - - //TODO: Add support for unrolling small constant kernel dimensions - //RunCompilationPass("UnrollOperations", [&]() { UnrollOperations(); }, true); - //RunCompilationPass("SqueezeKernelShapes", [&]() { SqueezeKernelShapes(); }); - - RunCompilationPass("FinalizeMemoryIndexing", [&]() { FinalizeMemoryIndexing(); }); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }); - RunCompilationPass("OptimizeKernels", [&]() { OptimizeKernels(); }); - RunCompilationPass("OptimizeHost", [&]() { OptimizeHost(); }); - RunCompilationPass("OptimizeOperations", [&]() { OptimizeOperations(); }); - RunCompilationPass("OptimizeHostValuesWithHints", [&]() { OptimizeHostValuesWithHints(); }); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }); - - RunCompilationPass("RemoveUnusedKernels", [&]() { RemoveUnusedKernels(); }, true); - 
RunCompilationPass("AddMemoryDeallocation", [&]() { AddMemoryDeallocation(); }, true); - //RunCompilationPass("CheckFinalKernelShapes", [&]() { CheckKernelShapes(); }); - }, true); - - RunCompilationPass("GetOutputList", [&]() { GetOutputList(); }); - RunCompilationPass("ComputeStatistics", [&]() { ComputeStatistics(); }); - -#ifdef PROFILE_COMPILATION - cout << "Profiled compilation passes:" << endl; - for (const PassStats& stats : pass_stats) { - cout << "Pass: " << stats.pass_name << " took " << stats.duration << "ms and processed " << stats.node_count << " nodes" << endl; - } -#endif -} - -int GetAxis(int dims, int axis) -{ - if (axis < 0) - { - axis = dims + axis; - } - return axis; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/IR.h b/TensorFrost/Compiler/Graph/IR.h deleted file mode 100644 index b0629ec9..00000000 --- a/TensorFrost/Compiler/Graph/IR.h +++ /dev/null @@ -1,413 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "Compiler/Operations.h" -#include "Utility/Utility.h" -#include "Node.h" -#include "Scope.h" - -namespace TensorFrost { - -class IR { -public: - Node* root; - NodeIterator cursor; - - IR() { - root = new Node(); - root->index_ = 0; - root->name = "root"; - root->initialize(nullptr, {}, "host", {TFType::None, 0}, existing_nodes, true); - cursor = NodeIterator(root); - } - - ~IR() { - vector to_delete; - for (auto node = begin(); !node.end(); node.next()) { - to_delete.push_back(*node); - } - for (Node* node : to_delete) { - delete node; - } - delete root; - } - - NodeIterator begin() const { - return NodeIterator(root); - } - - Node* AddNode(Tensor* tensor, NodeArguments&& args, string&& name, TFDataFormat type) { - Node* newNode = nullptr; - if (cursor->valid()) { //already initialized, add new node before cursor - newNode = new Node(cursor->prev, cursor->parent); - if 
(cursor->prev) - cursor->prev->next = newNode; - else if (cursor->parent) - cursor->parent->child = newNode; - cursor->prev = newNode; - newNode->next = *cursor; - newNode->initialize(tensor, std::move(args), std::move(name), type, existing_nodes); - } else { - newNode = cursor.get(); - cursor->initialize(tensor, std::move(args), std::move(name), type, existing_nodes); - cursor.go_to_next(); - } - -#ifndef NDEBUG - newNode->created_in_pass = current_pass; - newNode->created_in_function = current_function; -#endif - - return newNode; - } - - void MoveNodeTo(Node* target_place, Node* note_to_move) { - if (note_to_move->valid()) { - //remove from current position - if (note_to_move->parent && note_to_move->parent->child == note_to_move) { - note_to_move->parent->child = note_to_move->next; - } - else if (note_to_move->prev) { - note_to_move->prev->next = note_to_move->next; - } - note_to_move->next->prev = note_to_move->prev; - - //insert into new position - note_to_move->parent = target_place->parent; - note_to_move->prev = target_place->prev; - note_to_move->next = target_place; - if (target_place->prev) { - target_place->prev->next = note_to_move; - } - else if (target_place->parent) { - target_place->parent->child = note_to_move; - } - target_place->prev = note_to_move; - } - } - - void RemoveNode(Node* node); - - void SetCursor(Node* node) { - if(node != nullptr) { - cursor = NodeIterator(node, root); - } else { - throw std::runtime_error("Cursor cannot be set to null"); - } - } - - bool LimitKernelMemoryDependencies(); - - void UnrollOperations(); - - stack scope_stack; - - void EndScope() { - if (scope_stack.empty()) throw std::runtime_error("No scope to end"); - SetCursor(scope_stack.top()); - scope_stack.pop(); - } - - void BeginScope(Node* node) { - scope_stack.push(*cursor); - SetCursor(node); - } - - void BeginScopeLastChild(Node* node) { - BeginScope(node->GetLastChild()); - } - - void ExecuteExpressionAfter(Node* node, const function&& expression) { - 
BeginScope(node->next); - expression(); - EndScope(); - } - - void ExecuteExpressionBefore(Node* node, const function&& expression) { - BeginScope(node); - expression(); - EndScope(); - } - - void ExecuteExpressionFirstChild(Node* node, const function&& expression) { - BeginScope(node->child); - expression(); - EndScope(); - } - - void ExecuteExpressionLastChild(Node* node, const function&& expression) { - BeginScopeLastChild(node); - expression(); - EndScope(); - } - - void CheckIR(string name, bool check_clustering, bool check_kernels); - string PrintListing(map node_debug) const; - - string GetNodeListing(Node *node) const; - - map CopyNodes(set nodes_to_copy, - unordered_map argument_replacements, - unordered_map indices, - unordered_set targets, bool must_copy_all); - map CopyComputation(const unordered_set& targets, - const unordered_map& indices); - void GetInputList(); - void GetOutputList(); - void ComputeStatistics(); - void CopyArguments(ArgEdges args_to_copy, Node *cursor); - map CopyNodesWithIndex(unordered_set nodes_to_copy, - unordered_map indices, Node* cursor = nullptr); - void ReorderOperations(); - void MoveShapeOutsideKernels(); - bool OptimizeKernels(); - void OptimizeHost(); - void OptimizeOperations(); - void OptimizeHostValuesWithHints(); - - bool OptimizeKernelLoadOperations(); - void OptimizeReductions(); - - unordered_set GetDependencies(unordered_set nodes); - - void RemoveUnusedOperations(); - - bool InsertAlgorithmicPrimitives(bool skip_differentiable); - void UnrollLoops(int max_iterations = 8); - void UnrollAtomicOperations(); - void TryReplaceModificationsWithVersions(); - bool ComputeAutodiff(); - void SeparateOperationsIntoKernels(); - void ComputeNodeCost(); - - map GetKernelOutputs(Node *kernel); - void AddNodeLoadOperations(Node* node, Node* kernel, Tensors indices); - void AddKernelGlobalLoadOperations(); - void AddMemoryOpIndices(); - void AddKernelGlobalStoreOperations(); - - unordered_set ComputeKernelDependencies(Node* 
kernel); - - void CheckKernelShapes(); - void UpdateKernelShapes(); - void AddMemoryDeallocation(); - void RunCompilationPass(string pass_name, const function &expression, bool print = false, bool update_graph = false); - - bool RunIterativeCompilationPass(string pass_name, int max_iterations, const function &expression, - bool print = false, - bool update_graph = false); - - void ReplaceDimNodes(Node* kernel, Tensors indices, int dims); - void MultiDimensionalModeIndices(vector& indices, Node* kernel_, - int dims, Tensors kernel_shape); - Tensor* LinearBlockModeIndices(Tensors& indices, Node* kernel_, int dims, - Tensors kernel_shape); - - void ComputeAddress(Node *node, Tensors indices); - - void FinalizeMemoryIndexing(); - void RemoveUnusedKernels(); - void CompileIR(); - - void UpdateIndex() { - int index = 0; - for (auto node = begin(); !node.end(); node.next()) { - node->UpdateEdges(); - node->index_ = index++; - } - - index++; //add root node - - if (index != existing_nodes.size()) { - unordered_set found_nodes; - found_nodes.insert(root); - for (auto node = begin(); !node.end(); node.next()) { - found_nodes.insert(*node); - } - - unordered_set missing_nodes; - for (auto node : existing_nodes) { - if (found_nodes.find(node) == found_nodes.end()) { - missing_nodes.insert(node); - } - } - - string missing_nodes_str = ""; - for (auto node : missing_nodes) { - missing_nodes_str += GetNodeListing(node) + "\n"; - } - - missing_nodes_str += "\n" + PrintListing({}); - - throw std::runtime_error("\n Some nodes got lost during indexing. Expected " + to_string(existing_nodes.size()) + " but got " + to_string(index) + ". Likely invalid graph. 
Missing nodes:\n" + missing_nodes_str); - } - } - - void UpdateGraph(const Node* uroot = nullptr) { - if (uroot == nullptr) { - uroot = root; - } - - UpdateIndex(); - -#ifdef _DEBUG - map invalid_nodes; - // check if graph is valid - for (auto node = NodeIterator(uroot); !node.end(); node.next()) { - // if there are null inputs throw an error - for (auto& [id, n] : (*node)->args.Inputs()) { - if (n == nullptr) { - throw std::runtime_error("Null input found in node " + (*node)->var_name + ". Likely an icorrectly deleted node."); - } else if (n->index_ > (*node)->index_) { //if input node is after current node, throw an error - invalid_nodes[*node] = "Argument " + TypeToString(id.first) + ":" + - to_string(id.second) + " " + n->var_name + " is after current node"; - } - } - } - - if(invalid_nodes.size() > 0) { -#ifdef NDEBUG - std::string error = "Invalid graph: "; - for (auto [node, message] : invalid_nodes) { - error += GetNodeListing(node) + ": " + message + "\n"; - } - throw std::runtime_error(error); -#else - throw std::runtime_error("Invalid graph: " + PrintListing(invalid_nodes)); -#endif - } - -#endif - - //update modified flags - for (auto node = NodeIterator(uroot); !node.end(); node.next()) { - node->flags.remove(NodeProp::Modified); - //go over all outputs and check if they are modifiers - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (to->op->HasAllTypes(OpProp::Modifier)) { - bool is_memory = false; - if (id.first != ArgType::Memory) { - is_memory = true; - } - if (!is_memory) { - node->flags.set(NodeProp::Modified); - break; - } - } - } - } - } - - vector GetNodesOfType(const string& name) const { - vector result; - for (auto node = begin(); !node.end(); node.next()) { - if (node->name == name) { - result.push_back(*node); - } - } - return result; - } - - template - vector GetNodesOfType(OpProp type, Args... 
args) const { - vector result; - for (auto node = begin(); !node.end(); node.next()) { - if (node->op->HasAllTypes(type, args...)) { - result.push_back(*node); - } - } - return result; - } - - size_t CountNodesOfType(OpProp type) const { - size_t count = 0; - for (auto node = begin(); !node.end(); node.next()) { - if (node->op->HasAllTypes(type)) { - count++; - } - } - return count; - } - - vector GetChildren(Node* node) const { - vector result; - for (auto child = NodeIterator(node); !child.end(); child.next()) { - result.push_back(*child); - } - return result; - } - - size_t CombineHashes(size_t hash1, size_t hash2) const { - return hash1 ^ (hash2 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2)); - } - - size_t GetApproximateStateHash() const { - size_t hash = 0; - for (auto node = begin(); !node.end(); node.next()) { - size_t node_hash1 = std::hash{}(node->name); - size_t node_hash2 = std::hash{}(node->debug_name); - size_t node_hash3 = std::hash{}(node->debug_index); - hash = CombineHashes(hash, CombineHashes(node_hash1, CombineHashes(node_hash2, node_hash3))); - } - return hash; - } - - int input_memory_count = 0; - int output_memory_count = 0; - int temp_memory_count = 0; - - int readbacks = 0; - int writebacks = 0; - - unordered_map> shape_memory_map; - unordered_map input_memory_map; - unordered_map output_memory_map; - - string current_pass = "Tracing initial graph"; - string current_function = "None"; - - struct PassStats { - string pass_name; - float duration; - int node_count; - }; - vector pass_stats; - - void ReplaceArgs(const ArgEdges& edges, const map& replacements) { - edgesToUpdate.insert(edges.begin(), edges.end()); - replacementNodes.insert(replacements.begin(), replacements.end()); - } - - void RemoveNodes(const vector& nodes) { - removedNodes.insert(removedNodes.end(), nodes.begin(), nodes.end()); - } - - void ApplyChanges(bool update_graph = true, const Node *uroot = nullptr); - void ClearChanges(); - - ArgEdges edgesToUpdate{}; - map 
replacementNodes{}; - vector removedNodes{}; - unordered_set existing_nodes{}; - - static int max_kernel_memory_dependencies; - static int max_allowed_memory_dependencies; -}; - -int GetAxis(int dims, int axis); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Node.cpp b/TensorFrost/Compiler/Graph/Node.cpp deleted file mode 100644 index 262da7ec..00000000 --- a/TensorFrost/Compiler/Graph/Node.cpp +++ /dev/null @@ -1,365 +0,0 @@ -#include "IR.h" - -namespace TensorFrost { - -int Node::global_index = 0; - -void ArgumentManager::UpdateOutputs() { - for (auto& [id, node] : inputs_) { - node->args.AddOutput(id, node_); - } -} - -const Tensor *ArgumentManager::GetTensor(ArgType type, int index) const { - return Get(type, index)->GetTensor(); -} - -const Tensor & ArgumentManager::operator[](int index) const { - return *GetTensor(ArgType::Input, index); -} - -map ArgumentManager::GetTensors(ArgType type) const { - map tensors; - for (auto& [id, node] : inputs_) { - if (id.first == type) { - tensors[id.second] = node->GetTensor(); - } - } - return tensors; -} - -ArgumentManager::~ArgumentManager() { - -} - -bool ArgumentManager::CannotMoveArgument(ArgID id) { - Node* from = inputs_[id]; - Node* to = node_; - return (id.first == ArgType::Memory && - !to->op->HasAllTypes(OpProp::Set)) || - (id.first == ArgType::Shape && !to->op->HasAllTypes(OpProp::Memory)) || - from->op->HasAllTypes(OpProp::Memory) || - (from->name == "const" && to->op->HasAllTypes(OpProp::Memory)); //FIX THIS -} - -bool ArgumentManager::CannotCopyArgument(ArgID id) { - Node* from = inputs_[id]; - Node* to = node_; - bool shape = id.first == ArgType::Shape; - bool to_memory = to->op->HasAllTypes(OpProp::Memory); - bool shape_not_memory = shape && !to_memory; - bool is_output = from->flags.has(NodeProp::OutputMemory); - bool is_input = from->flags.has(NodeProp::InputMemory); - bool no_fusion = from->flags.has(NodeProp::StopFusion); - bool cannot_copy = id.first 
== ArgType::Memory || shape_not_memory || - from->op->HasAllTypes(OpProp::Static) || from->op->HasAllTypes(OpProp::Memory) || - from->flags.has(NodeProp::Modified) || is_output || is_input || no_fusion; - return cannot_copy; -} - -bool ArgumentManager::IsChangingInput(ArgID arg) { - return arg.first == ArgType::Memory && - node_->op->HasAllTypes(OpProp::Modifier); -} - -void ArgumentManager::UpdateArgument(ArgID id, Node *node) { - if(node == nullptr) { - throw std::runtime_error("ArgumentManager: Node is null"); - } - if(!Has(id)) { - throw std::runtime_error("ArgumentManager: Argument " + TypeToString(id.first) + ":" + std::to_string(id.second) + " with node " + node->name + " does not exist"); - } - // inputs_[id]->args.RemoveOutput(id, node_); - // inputs_[id] = node; - // argument_types_[id] = node->type; - // node->args.AddOutput(id, node_); - Remove(id); - AddArgument(id, node); -} - -Node * Node::GetChild(string name) { - for(NodeIterator it = NodeIterator(this); !it.end(); it.next()) { - if(it->name == name) { - return it.get(); - } - } - return nullptr; -} - -Node * Node::GetNodeWithCommonParent(Node *other) { - for (Node* cur_parent = this; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (cur_parent->parent == other->parent) { - return cur_parent; - } - } - for (Node* cur_parent = other; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (cur_parent->parent == this->parent) { - return cur_parent; - } - } - for (Node* cur_parent1 = this; cur_parent1 != nullptr; - cur_parent1 = cur_parent1->parent) { - for (Node* cur_parent2 = other; cur_parent2 != nullptr; - cur_parent2 = cur_parent2->parent) { - if (cur_parent1->parent == cur_parent2->parent) { - return cur_parent1; - } - } - } - throw std::runtime_error("No common parent found"); -} - -Node* Node::GetLastChild() { - NodeIterator it = NodeIterator(this); - for (; !it.end(); it.go_to_next()) {} - return it.get(); -} - -vector Node::GetChildren() { - vector children; - 
for(NodeIterator it = NodeIterator(this); !it.end(); it.go_to_next()) { - children.push_back(it.get()); - } - return children; -} - -bool Node::HasCommonParents(Node *other, int max_depth) const { - int depth = 0; - for (Node* cur_parent = parent; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (depth++ > max_depth) { - break; - } - if (!other->HasParent(cur_parent)) { - return false; - } - } - return true; -} - -bool Node::HasParent(string name) { - return GetParent(name) != this; -} - -bool Node::HasChild(string name) { - return GetChild(name) != nullptr; -} - -void Node::ValidateParentShapes() const { - //compare the shape of this node with the shape of all its parents - ShapeInfo shape = ShapeInfo(this); - for (Node* cur_parent = parent; cur_parent != nullptr; cur_parent = cur_parent->parent) { - ShapeInfo parent_shape = ShapeInfo(cur_parent); - ShapeCompareResult result = CompareShape(parent_shape, shape); //must only be broadcastable - if (!result.compatible) { - throw std::runtime_error(MakeNodeErrorMessage("The node " + debug_name + " (" + name + ") has incompatible shapes with its parent " + - cur_parent->debug_name + " (" + cur_parent->name + ")", {this, cur_parent})); - } - } -} - -void Node::SetMemoryType(NodeProp memory_type, int index) { - flags.set(memory_type, (int64_t)index); -} - -void Node::CheckNode() const { - // must have operation - if (op == nullptr) { - throw std::runtime_error("Operation object not found"); - } - - // must have tensor - if (tensor_ == nullptr && !flags.has(NodeProp::IsStatic)) { - throw std::runtime_error("Tensor not found"); - } - - //validate the shape of the node if its not scalar - if (args.Count(ArgType::Shape) > 0) { - ValidateParentShapes(); - } -} - -Node * Node::GetLastVersion(Node *latest_node) { - //find last store/scatter operation - Node* last_modifier = this; - int last_index = -1; - Node* loop_node = latest_node->GetParent("loop"); - bool has_loop = loop_node != latest_node; - for (auto [edge, 
to] : args.Outputs()) { - auto& [id, from] = edge; - bool is_memory = false; - if (id.first != ArgType::Memory) { - is_memory = true; - } - if (is_memory) { - continue; - } - if (to->op->HasAllTypes(OpProp::Modifier)) { - if (to->index_>last_index) { - // either find the last modifier or the last memory node - // or if there is a loop, find the last modifier inside the loop (i.e. - // the previous iteration's modifier) - // if the loop is scalar, then it doesn't matter - bool before_latest = to->index_ < latest_node->index_; - bool inside_loop = has_loop && to->HasParent(loop_node); - bool not_same = to != latest_node; - if ((before_latest || inside_loop) && not_same) - { - last_index = to->index_; - last_modifier = to; - } - } - } - } - return last_modifier; -} - -Node * Node::GetFinalVersion() { - Node* final_version = this; - int last_index = -1; - for (auto [edge, to] : args.Outputs()) { - auto& [id, from] = edge; - bool is_memory = false; - if (id.first != ArgType::Memory) { - is_memory = true; - } - if (is_memory) { - continue; - } - if (to->op->HasAllTypes(OpProp::Modifier) && !to->op->HasAllTypes(OpProp::MemoryOp)) { - if (to->index_ > last_index) { - last_index = to->index_; - final_version = to; - } - } - } - return final_version; -} - -const map flag_names = { - {NodeProp::Modified, "Modified"}, {NodeProp::Placeholder, "Placeholder"}, - {NodeProp::DetachGrad, "DetachGrad"}, {NodeProp::PassGrad, "PassGrad"}, - {NodeProp::KeepDims, "KeepDims"}, {NodeProp::IsStatic, "IsStatic"}, - {NodeProp::OutputMemory, "OutputMemory"}, {NodeProp::InputMemory, "InputMemory"}, - {NodeProp::InputMemoryList, "InputMemoryList"}, {NodeProp::InputShapeMemory, "InputShapeMemory"}, - {NodeProp::InputShapeDim, "InputShapeDim"}, {NodeProp::NoCopyFusion, "NoCopyFusion"}, - {NodeProp::NoLoadFusion, "NoLoadFusion"}, {NodeProp::StopFusion, "StopFusion"}, {NodeProp::HintMaxValue, "HintMaxValue"}, - {NodeProp::HintMinValue, "HintMinValue"}, {NodeProp::LocalMemoryOp, "LocalMemoryOp"} -}; 
- -string NodeFlagsToString(NodeProp flags) { - if (!flag_names.contains(flags)) { - throw std::runtime_error("Flag name not defined"); - } - return flag_names.at(flags); -} - -void Node::UpdateEdges() { - if (!child) child = new Node(nullptr, this); - if (!next) next = new Node(this, parent); - if (child->valid()) { - child->parent = this; - } - if (next->valid()) { - next->prev = this; - next->parent = parent; - } -} - -const map indexing_mode_names = { - {IndexingMode::Clamp, "Clamp"}, {IndexingMode::Repeat, "Repeat"},{IndexingMode::Unsafe, "Unsafe"} -}; - -string IndexingModeToString(IndexingMode mode) { - return indexing_mode_names.at(mode); -} - -void Node::initialize(Tensor *tensor, NodeArguments &&new_args, string &&new_name, TFDataFormat new_format, unordered_set& existing_nodes, bool set_static) { - if(valid()) { - throw runtime_error("Node already initialized"); - } - UpdateEdges(); - flags.remove(NodeProp::Placeholder); - - tensor_ = tensor; - format = new_format; - args.AddArguments(std::move(new_args)); - //args.UpdateOutputs(); - flags.set(NodeProp::IsStatic, set_static); - name = std::move(new_name); - op = FindOperation(name); - CheckNode(); - existing_nodes.insert(this); -} - -void Node::CopyProperties(Node *other) { - name = other->name; - debug_name = other->debug_name; - indexing_mode_ = other->indexing_mode_; - group_size = other->group_size; - format = other->format; - - flags.copy_all(other->flags); -} - -void Node::CopyMetadata(Node *other) { - if (other->debug_name != "") { - debug_name = other->debug_name; - } - if(other->indexing_mode_ != IndexingMode::Clamp) { - indexing_mode_ = other->indexing_mode_; - } - group_size = other->group_size; - - flags.copy_all_except(other->flags, {NodeProp::Modified}); -} - -int Node::ComputeDepth(Node *root) const { - int depth = 0; - for (const Node* node = this; node != root; node = node->parent) { - depth++; - } - return depth; -} - -bool Node::HasParent(Node *node) const { - for (Node* cur_parent = 
parent; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (cur_parent == node) { - return true; - } - } - return false; -} - -void Node::ReplaceThisWithGivenNode(Node *replacement, int min_index, bool make_modified, bool copy_metadata, set nodes_to_modify) { - try { - for (auto [edge, to] : args.OutputsCopy()) { - auto& [id, from] = edge; - if(nodes_to_modify.size() > 0 && !nodes_to_modify.contains(to->name)) { - continue; - } - if (to->index_ >= min_index) { - if(make_modified) { - replacement->flags.set(NodeProp::Modified); - } - to->args.UpdateArgument(id, replacement); - } - } - - if(copy_metadata) { - replacement->CopyMetadata(this); - this->flags.clear(); - } - } catch (const std::exception& e) { - throw std::runtime_error(MakeNodeErrorMessage("Failed to replace node with another node: " + std::string(e.what()), {this, replacement})); - } -} - -Node * Node::GetParent(string name) { - for (Node* cur_parent = parent; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (cur_parent->name == name) { - return cur_parent; - } - } - return this; -} -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Node.h b/TensorFrost/Compiler/Graph/Node.h deleted file mode 100644 index 9943da24..00000000 --- a/TensorFrost/Compiler/Graph/Node.h +++ /dev/null @@ -1,289 +0,0 @@ -#pragma once - -#include "Compiler/Operations.h" -#include "Utility/Utility.h" -#include "Arguments.h" - -namespace TensorFrost { - -enum class NodeProp { - Placeholder, - Modified, - IsStatic, - OutputMemory, - InputShapeDim, - InputShapeMemory, - InputMemory, - InputMemoryList, - KeepDims, - DetachGrad, - PassGrad, - NoLoadFusion, - NoCopyFusion, - StopFusion, - HintMaxValue, - HintMinValue, - LocalMemoryOp, - Count, -}; - -enum class MemoryType { - None, - Input, - Output, - Constant, -}; - -enum class IndexingMode { - Unsafe, - Clamp, - Repeat -}; - -struct TFTypeDesc { - TFType type; - IndexingMode indexing_mode; - uint constant_value; - 
-}; - -string NodeFlagsToString(NodeProp flag); -string IndexingModeToString(IndexingMode mode); -using NodeProps = FlagSet; - -class Node { - static int global_index; - public: - int debug_index = -1; - string name; - string var_name = ""; - string debug_name; - int index_ = -1; - float cost_ = -1.0f; - unordered_set memory_deps; - - //Edge *prev, *next; - // - Node *parent = nullptr, *child = nullptr, *next = nullptr, *prev = nullptr; - const Operation* op; - NodeProps flags; - ArgumentManager args; - const Tensor* tensor_; - TFDataFormat format = {TFType::Float, 32}; - std::vector data; - IndexingMode indexing_mode_; //clamp unless otherwise specified - vector group_size; //kernel properties - -#ifndef NDEBUG - string created_in_pass; - string created_in_function; -#endif - - Node(Node* prev = nullptr, Node* parent = nullptr) : parent(parent), prev(prev), args(this) { - flags.set(NodeProp::Placeholder); - debug_index = global_index++; - indexing_mode_ = IndexingMode::Clamp; - } - - bool valid() { - return !flags.has(NodeProp::Placeholder); - } - - void UpdateEdges(); - - //initialize and create next/child placeholders - void initialize(Tensor* tensor, NodeArguments&& new_args, string&& new_name, TFDataFormat new_format, unordered_set& existing_nodes, bool set_static = false); - - void CopyProperties(Node* other); - void CopyMetadata(Node* other); - - const Tensor* GetTensor() const; - int ComputeDepth(Node* root = nullptr) const; - bool HasParent(Node* node) const; - - /// - /// Make all outputs of this node use the given node as input, assuming that the output is further than min_index - /// - /// - /// - void ReplaceThisWithGivenNode(Node* replacement, int min_index = -1, bool make_modified = false, bool copy_metadata = true, set nodes_to_modify = {}); - - Node* GetParent(string name); - Node* GetChild(string name); - - //get the parent that has a common first parent with another node - Node* GetNodeWithCommonParent(Node* other); - - Node* GetLastChild(); - 
vector GetChildren(); - - //checks if the other node has all parents as this node - bool HasCommonParents(Node* other, int max_depth = 128) const; - - bool HasParent(string name); - bool HasChild(string name); - - void ValidateParentShapes() const; - - void SetMemoryType(NodeProp memory_type, int index = 0); - void CheckNode() const; - Node* GetLastVersion(Node* latest_node); - Node* GetFinalVersion(); - - ~Node(); -}; - -//NodeIterator is a depth first iterator that iterates through the child nodes of a root node -class NodeIterator { - public: - Node* currentNode; - Node* currentParent; - Node* root; - -#ifndef NDEBUG - int iteration_count = 0; - int parent_inconsistency_count = 0; - unordered_set visited; - vector path; -#endif - - NodeIterator() : currentNode(nullptr), root(nullptr), currentParent(nullptr) {} - NodeIterator(Node* node, Node* root) : currentNode(node), root(root), currentParent(node->parent) {} - NodeIterator(const Node* node, const Node* root) - : currentNode(const_cast(node)), root(const_cast(root)), currentParent(const_cast(node->parent)) {} - NodeIterator(Node* node_root) - : currentNode(node_root->child), root(node_root), currentParent(node_root) {} - NodeIterator(const Node* node_root) - : currentNode(const_cast(node_root->child)), - root(const_cast(node_root)), - currentParent(const_cast(node_root)) {} - - Node* operator*() const { return currentNode; } - void update_current_node(Node* new_node) { - currentNode = new_node; -#ifndef NDEBUG - if (visited.contains(currentNode)) { - throw std::runtime_error("Node already visited, potential cycle in operation graph"); - } - visited.insert(currentNode); - path.push_back(currentNode); - iteration_count++; -#endif - } - - NodeIterator& go_to_next() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - update_current_node(currentNode->next); -#ifndef NDEBUG - if (currentNode->parent != currentParent) { - parent_inconsistency_count++; - } -#endif - - return *this; - } - - 
NodeIterator& go_to_parent() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - update_current_node(currentNode->parent); -#ifndef NDEBUG - currentParent = currentNode->parent; -#endif - - return *this; - } - - NodeIterator& go_to_child() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - -#ifndef NDEBUG - currentParent = currentNode->parent; -#endif - update_current_node(currentNode->child); - - return *this; - } - - NodeIterator& up() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - if (root == currentNode) { - throw std::runtime_error("Already at root"); - } - - if (currentNode->parent != root) { - go_to_parent(); - } else { - go_to_next(); - } - return *this; - } - - NodeIterator& forward() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - if (!currentNode->valid()) { - return *this; - } - - if (!currentNode->next->valid()) { // no next, try going up - Node* parent = currentNode->parent; - while (!parent->next->valid() && root != parent) { - parent = parent->parent; - } - if (root != parent) { // go to next sibling - currentNode = parent; - } - } - - // just go to next node and stop if it's the end - go_to_next(); - return *this; - } - - // first child, then next - NodeIterator& next() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - if (!currentNode->valid()) { - return *this; - } - - if (currentNode->child->valid()) { // has child, go down - go_to_child(); - return *this; - } - - forward(); - - return *this; - } - - bool end() { return !currentNode->valid(); } - - Node* operator->() { return currentNode; } - - Node* get() { return currentNode; } - - int depth() { return currentNode->ComputeDepth(root); } - - bool operator!=(const Node* node) { return currentNode != node; } -}; - -std::string MakeNodeErrorMessage(std::string message, std::initializer_list nodes); - -} // namespace TensorFrost \ No newline at end of file diff --git 
a/TensorFrost/Compiler/Graph/Scope.cpp b/TensorFrost/Compiler/Graph/Scope.cpp deleted file mode 100644 index 001ffc3e..00000000 --- a/TensorFrost/Compiler/Graph/Scope.cpp +++ /dev/null @@ -1,325 +0,0 @@ -#include "Scope.h" - -namespace TensorFrost { - -inline bool KernelScope::IsBoundary(const Node* input, const Node* output, - ArgType arg_type, bool is_identity) { - const Operation* input_op = input->op; - const Operation* output_op = output->op; - - // if this node loads something from another node, that node must not be in - // this kernel - if (output_op->HasAllTypes(OpProp::Load, OpProp::MemoryOp)) { - return arg_type == ArgType::Memory; - } - - // if we are modifying memory, then the modified memory must not be in the - // kernel - if (output_op->HasAnyType(OpProp::Scatter, OpProp::Store) && - !input_op->HasAnyType(OpProp::Scatter, OpProp::Store)) { - return arg_type == ArgType::Memory; - } - - //if the input is a scatter, then this is a boundary - if (input_op->HasAnyType(OpProp::Scatter)) { - //but multiple scatters can be in the same kernel - return !(output_op->HasAnyType(OpProp::Scatter)); - } - - // shape should not be inside kernels - if (arg_type == ArgType::Shape) { - return true; - } - - return false; -} - -bool KernelScope::IsValid() const { - // check if the scope is valid - if (begin == nullptr || end == nullptr) return false; - - // begin and end must have the same parent - if (begin->parent != end->parent) return false; - - if (begin->index_ < 0 || end->index_ < 0) - throw std::runtime_error("Indices are not computed"); - - // begin must be before or at the same index as end - if (begin->index_ > end->index_) return false; - - // check if the boundary nodes are not in the scope - for (Node* boundary_node : boundary_nodes) { - if (boundary_node->index_ >= begin->index_ && - boundary_node->index_ <= end->index_) { - return false; - } - } - - return true; -} - -KernelScope::KernelScope(Node* node, - unordered_set& output_scopes) : begin(node), 
end(node) { - scope_shape = ShapeInfo(node); - - // if host only, then this can not be a valid kernel scope - if (node->op->HasAllTypes(OpProp::HostOnly)) { - begin = nullptr; - end = nullptr; - return; - } - - if(node->op->HasAllTypes(OpProp::Modifier, OpProp::MemoryOp) && !node->op->HasAllTypes(OpProp::Scatter)) { - if(scope_shape.dim == 0) { - //must be at 1d scalar to properly generate the kernel with a non-atomic write operation - //technically this is also applicable to scatter operations, but when making an atomic counter, it is sometimes nice for its shape to be implicitly - //determined by the shape of the neighboring tensors - scope_shape.ExpandDimensionsTo(1); - } - } - - // find boundary nodes - bool identity = node->args.Count(ArgType::Index) == 0; - - for (auto& input : node->args.Inputs()) { - // get latest input version - Node* latest = input.second->GetLastVersion(node); - // check if input is the boundary of this kernel - bool is_loop_boundary = latest->index_ > node->index_; - if (IsBoundary(latest, node, input.first.first, identity)) { - if (is_loop_boundary) { - latest = latest->GetParent("loop"); - } - boundary_nodes.insert(latest); - } - } - - pair, bool> all_scopes = ComputeScopes(node); - auto child_scopes = all_scopes.first; - bool host_only = all_scopes.second; - - output_scopes.insert(child_scopes.begin(), child_scopes.end()); - - if(host_only) { - begin = nullptr; - end = nullptr; - return; - } - - int scope_count = (int)child_scopes.size(); - if (scope_count == 0) return; - - //if there is more than one child scope, then this node can not be in the scope - if (scope_count > 1) { - begin = nullptr; - end = nullptr; - return; - } - - KernelScope* child_scope = *child_scopes.begin(); - AddBoundaryNodes(child_scope->boundary_nodes); - - ShapeCompareResult result = - CompareShape(scope_shape, child_scope->scope_shape, true); - - if (result.compatible) { - scope_shape = result.broadcast_shape; - } else { - throw std::runtime_error("Something 
went wrong"); - } -} - -pair, bool> KernelScope::ComputeScopes(Node *root) { - std::unordered_set scopes; - KernelScope* current_scope = new KernelScope(); - bool host_only = false; - for (auto node = NodeIterator(root); !node.end(); node.go_to_next()) { - std::unordered_set child_scopes; - KernelScope* node_scope = new KernelScope(node.get(), child_scopes); - if (node_scope->IsValid()) { // can be merged - KernelScope* merged = KernelScope::Merge(current_scope, node_scope); - if (merged->IsValid()) { - current_scope = merged; - } else { - bool current_is_valid = current_scope->IsValid(); - if (current_is_valid) { - scopes.insert(current_scope); - } - current_scope = node_scope; - } - } else { // has child kernels - // add all child scopes - scopes.insert(child_scopes.begin(), child_scopes.end()); - // add current scope - bool current_is_valid = current_scope->IsValid(); - if (current_is_valid) { - scopes.insert(current_scope); - } - // create a new empty scope - current_scope = new KernelScope(); - host_only = true; - } - } - if (current_scope->IsValid()) { - scopes.insert(current_scope); - } - return {scopes, host_only}; -} - -KernelScope* KernelScope::Merge(KernelScope* a, KernelScope* b) { - bool a_valid = a->IsValid(); - bool b_valid = b->IsValid(); - - if (!a_valid && !b_valid) - throw std::runtime_error("Invalid kernel scopes for merging"); - - if (!a_valid) return b; - if (!b_valid) return a; - - if (a->end->next != b->begin) - throw std::runtime_error("Trying to merge non-adjacent kernel scopes"); - - ShapeCompareResult result = CompareShape(a->scope_shape, b->scope_shape); - - if (!result.exactly_compatible) return new KernelScope(); - - KernelScope* new_scope = new KernelScope(a->begin, b->end, result.broadcast_shape, a->boundary_nodes); - new_scope->AddBoundaryNodes(b->boundary_nodes); - - if (!new_scope->IsValid()) return new KernelScope(); - - return new_scope; -} - -//if shape nodes are compatible, then return the broadcast shape, if not return 
nullptr -ShapeDimCompareResult CompareShapeDim(Node* a_node, Node* b_node) { - ShapeDimCompareResult result; - result.compatible = false; - result.broadcast = false; - result.exactly_compatible = true; - result.unroll_compatible = true; - result.a_dim = -1; - result.b_dim = -1; - result.unroll_dim = nullptr; - result.broadcast_dim = nullptr; - - if (a_node->name == "const") result.a_dim = a_node->data[0]; - if (b_node->name == "const") result.b_dim = b_node->data[0]; - - // if one of the nodes is a constant = 1, then it is a broadcast - if ((result.a_dim == 1 || result.b_dim == 1) && !(result.a_dim == 1 && result.b_dim == 1)) { - result.compatible = true; - result.broadcast = true; - result.exactly_compatible = false; - - if (result.a_dim == 1) { - result.unroll_dim = a_node; - result.broadcast_dim = b_node; - } else { - result.unroll_dim = b_node; - result.broadcast_dim = a_node; - } - } - - if(!result.compatible) { - // if a and b are constants, then compare their values - if (result.a_dim != -1 && result.b_dim != -1) { - result.compatible = result.a_dim == result.b_dim; - } - } - - if(!result.compatible) { - // otherwise, if a and b are not the same node then they are not the same - // shape (possibly) - result.compatible = a_node == b_node; - } - - if(result.unroll_dim == nullptr) { - result.unroll_dim = a_node; - } - - if(result.broadcast_dim == nullptr) { - result.broadcast_dim = a_node; - } - - if(result.broadcast) { - if(result.a_dim > MAX_DIM_UNROLL || result.b_dim > MAX_DIM_UNROLL) { - result.unroll_compatible = false; - } - } - - result.exactly_compatible = result.compatible && result.exactly_compatible; - result.unroll_compatible = result.compatible && result.unroll_compatible; - return result; -} - -ShapeCompareResult CompareShape(ShapeInfo& a, ShapeInfo& b, bool throw_error) { - ShapeCompareResult result; - result.compatible = true; - result.exactly_compatible = true; - result.unroll_compatible = true; - result.broadcast = false; - result.a_dim = 
a.dim; - result.b_dim = b.dim; - result.broadcast_dim = max(a.dim, b.dim); - ShapeInfo& max_dim_shape = a.dim > b.dim ? a : b; - - int min_dim = min(a.dim, b.dim); - - for (int i = 0; i < min_dim; i++) { - Node* a_node = a[i]; - Node* b_node = b[i]; - int broadcast_index = i; - - ShapeDimCompareResult res = CompareShapeDim(a_node, b_node); - - if(!res.compatible) { - result.compatible = false; - if (throw_error) { - if(res.a_dim != -1 && res.b_dim != -1) { - throw std::runtime_error("Shapes are not compatible for nodes: " + a.name + " and " + b.name + " with constant values " + to_string(res.a_dim) + " and " + to_string(res.b_dim) + " at index " + to_string(i)); - } - throw std::runtime_error("Shapes are potentially not compatible for nodes: " + a.name + " and " + b.name + " at index " + to_string(i)); - } - break; - } - - if(res.broadcast) { - result.broadcast = true; - result.broadcast_dims.insert(broadcast_index); - } - - result.unroll_compatible = result.unroll_compatible && res.unroll_compatible; - result.exactly_compatible = result.exactly_compatible && res.exactly_compatible; - result.broadcast_shape.AddShape(broadcast_index, res.broadcast_dim); - result.unroll_shape.AddShape(broadcast_index, res.unroll_dim); - } - - //add the rest of the broadcast shape - if(result.compatible) { - for (int i = min_dim; i < result.broadcast_dim; i++) { - result.broadcast = true; - result.broadcast_shape.AddShape(i, max_dim_shape[i]); - result.unroll_shape.AddShape(i, max_dim_shape[i]); - result.broadcast_dims.insert(i); - } - } - - if((result.broadcast && min_dim > 0) || !result.compatible) { - result.exactly_compatible = false; - } - - if (result.compatible && result.broadcast_shape.dim != result.broadcast_dim) { - throw std::runtime_error("Internal Error: Broadcast shape does not match the broadcast dim"); - } - - return result; -} - -ShapeCompareResult CompareShape(const Node* a, const Node* b, bool throw_error) { - ShapeInfo a_info = ShapeInfo(a); - ShapeInfo b_info = 
ShapeInfo(b); - return CompareShape(a_info, b_info, throw_error); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Scope.h b/TensorFrost/Compiler/Graph/Scope.h deleted file mode 100644 index 79ba7c7a..00000000 --- a/TensorFrost/Compiler/Graph/Scope.h +++ /dev/null @@ -1,173 +0,0 @@ -#pragma once - -#include "Compiler/Operations.h" -#include "Utility/Utility.h" -#include "Node.h" - -#define MAX_DIM_UNROLL 4 - -namespace TensorFrost { - -using Tensors = vector; - -class ShapeInfo { - public: - vector> shape; - int dim = 0; - string name; - - ShapeInfo() {} - - ShapeInfo(ShapeInfo* shape_info) { - shape = shape_info->shape; - dim = shape_info->dim; - name = shape_info->name; - } - - ShapeInfo(const Node* node) { - dim = node->args.Count(ArgType::Shape); - for (int i = 0; i < dim; i++) { - AddShape(i, node->args.Get(ArgType::Shape, i)); - } - this->name = node->var_name != "" ? node->var_name : node->name; - } - - Node* operator[](int index) const { - return shape[index].first; - - } - - void AddShape(int index, Node* node) { - if(shape.size() <= index) { - shape.resize(index + 1); - } - shape[index] = {node, false}; - dim = max(dim, index + 1); - } - - bool CheckValidity(bool throw_error = false) const { - for (auto node : shape) { - if(node.first == nullptr) { - if (throw_error) { - throw std::runtime_error("Shape not fully defined"); - } - return false; - } - } - return true; - } - - const Tensor* GetTensor(int index) const { - CheckValidity(true); - return shape[index].first->GetTensor(); - } - - bool IsExpanded(int index) const { - return shape[index].second; - } - - Tensors GetTensors() const { - CheckValidity(true); - Tensors tensors = Tensors(); - for (auto node : shape) { - tensors.push_back(node.first->GetTensor()); - } - return tensors; - } - - NodeArguments GetArguments() const { - CheckValidity(true); - NodeArguments arguments; - for (int i = 0; i < shape.size(); i++) { - arguments[ArgID(ArgType::Shape, 
i)] = shape[i].first; - } - return arguments; - } - - vector GetShape(int default_value = 256) const; - - static float GetSizeEstimate(ShapeInfo &shape); - - void InsertDim(int index, Node* node, bool expanded = false) { - if (index >= shape.size()+1) { - shape.resize(index + 1); - } - shape.insert(shape.begin() + index, {node, expanded}); - dim++; - } - - void ExpandDimensionsTo(int new_dim); -}; - -struct ShapeCompareResult { - bool compatible; - bool unroll_compatible; - bool exactly_compatible; - ShapeInfo broadcast_shape; - ShapeInfo unroll_shape; - set broadcast_dims; - bool broadcast; - int broadcast_dim; - int a_dim; - int b_dim; - int min_dim; -}; - -struct ShapeDimCompareResult { - bool compatible; - bool unroll_compatible; - bool exactly_compatible; - Node* broadcast_dim; - Node* unroll_dim; - bool broadcast; - int a_dim; - int b_dim; -}; - -ShapeCompareResult CompareShape(const Node* a, const Node* b, - bool throw_error = false); - -ShapeCompareResult CompareShape(ShapeInfo& a, ShapeInfo& b, - bool throw_error = false); - -ShapeDimCompareResult CompareShapeDim(Node* a_node, Node* b_node); - -/// -/// Class to select kernel scopes from the IR graph given the constraints and the root node -/// -class KernelScope { - public: - Node* begin = nullptr; - Node* end = nullptr; - ShapeInfo scope_shape; - unordered_set boundary_nodes; - - static bool IsBoundary(const Node* input, const Node* output, ArgType arg_type, bool is_identity); - - KernelScope() : begin(nullptr), end(nullptr) {} - KernelScope(Node* node, unordered_set& output_scopes); - - KernelScope(Node* begin, Node* end, ShapeInfo shape, unordered_set boundary_nodes) - : begin(begin), end(end), scope_shape(shape), boundary_nodes(boundary_nodes) {} - - void CopyProperties(KernelScope* other) { - begin = other->begin; - end = other->end; - scope_shape = other->scope_shape; - boundary_nodes = other->boundary_nodes; - } - - bool IsValid() const; - - static pair, bool> ComputeScopes(Node *root); - - static 
KernelScope* Merge(KernelScope* a, KernelScope* b); - - void CreateKernel(); - - void AddBoundaryNodes(unordered_set new_boundary_nodes) { - boundary_nodes.insert(new_boundary_nodes.begin(), new_boundary_nodes.end()); - } -}; - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Implementations.cpp b/TensorFrost/Compiler/Implementations.cpp deleted file mode 100644 index a42acdbe..00000000 --- a/TensorFrost/Compiler/Implementations.cpp +++ /dev/null @@ -1,783 +0,0 @@ -#include "Compiler/Implementations.h" - -namespace TensorFrost { - -const Tensor& ReduceGradientToShape(const Tensor& gradient, const Tensor& target) -{ - ShapeCompareResult shape_result = CompareShape(gradient.node_, target.node_); - if (!shape_result.compatible) { - throw std::runtime_error("Autodiff: gradient shape not compatible with target tensor"); - } - - if(!shape_result.broadcast) { - return gradient; - } - - int dim = shape_result.broadcast_dim; - ShapeInfo gradinfo = gradient.GetShapeInfo(); - ShapeInfo targetinfo = target.GetShapeInfo(); - - gradinfo.ExpandDimensionsTo(dim); - targetinfo.ExpandDimensionsTo(dim); - - vector axes_to_reduce; - vector unsqueeze; - for(int i = 0; i < dim; i++) { - int val_a = gradinfo.GetTensor(i)->TryGetConstant(); - int val_b = targetinfo.GetTensor(i)->TryGetConstant(); - bool b_expanded = targetinfo.IsExpanded(i); - if(b_expanded || (val_a != val_b && val_b == 1)) { - axes_to_reduce.push_back(i); - bool should_unsqueeze = i < target.GetDimension(); - unsqueeze.push_back(should_unsqueeze); - } - } - - Tensor* reduced = const_cast(&gradient); - //go in inverse order to keep the dimensions in the same order - for(int i = (int)axes_to_reduce.size() - 1; i >= 0; i--) { - reduced = &Tensor::Sum(*reduced, axes_to_reduce[i]); - if(unsqueeze[i]) { - reduced = &Tensor::Unsqueeze(*reduced, axes_to_reduce[i]); - } - } - -#ifndef NDEBUG - //check if the reduced shape is the same as the target shape - ShapeCompareResult result = CompareShape(reduced->node_, 
target.node_); - if(!result.compatible) { - throw std::runtime_error("Gradient shape not compatible with target tensor, function ReduceGradientToShape failed"); - } -#endif - - return *reduced; -} - -Tensor* ConstantOutOfBounds(const Tensor* array, Tensors indices, uint constant) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - Tensor* is_out_of_bounds = &Tensor::Constant(0, TFTypeBool32); - for (int i = 0; i < dims; i++) { - Tensor* is_out = &(*indices[i] < Tensor::Constant(0) || *indices[i] >= *shape[i]); - is_out_of_bounds = &(*is_out_of_bounds || *is_out); - } - is_out_of_bounds->SetDebugName("out_of_bounds"); - Tensor* value = &Tensor::Constant(constant, array->node_->format); - Tensor* loaded = &Tensor::Load(*array, indices); - return &Tensor::select(*is_out_of_bounds, *value, *loaded); -} - - -Tensor* IsOutOfBounds(const Tensor* array, Tensors indices) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - Tensor* is_out_of_bounds = &Tensor::Constant(0, TFTypeBool32); - for (int i = 0; i < dims; i++) { - Tensor* is_out = &(*indices[i] < Tensor::Constant(0) || *indices[i] >= *shape[i]); - is_out_of_bounds = &(*is_out_of_bounds || *is_out); - } - is_out_of_bounds->SetDebugName("out_of_bounds"); - return is_out_of_bounds; -} - - -map gradient_functions = -{ - //elementwise operations - {"copy", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad); }}, - {"add", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad, grad); }}, - {"sub", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad, -grad); }}, - {"mul", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[1], grad * in[0]); }}, - {"div", [](ArgumentManager& in, const Tensor& out, const 
Tensor& grad, NodeGrads& grads) { grads.Add(grad / in[1], -grad * in[0] / (in[1] * in[1])); }}, - {"neg", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(-grad); }}, - {"exp", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * out); }}, - {"log", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / in[0]); }}, - {"sin", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * Tensor::cos(in[0])); }}, - {"cos", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(-grad * Tensor::sin(in[0])); }}, - {"tan", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * (Tensor::Constant(1.0f) + out * out)); }}, - {"asin", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / Tensor::sqrt(Tensor::Constant(1.0f) - out * out)); }}, - {"acos", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(-grad / Tensor::sqrt(Tensor::Constant(1.0f) - out * out)); }}, - {"atan", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / (Tensor::Constant(1.0f) + out * out)); }}, - {"abs", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * Tensor::sign(in[0])); }}, - {"sign", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f)); }}, - {"exp2", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * Tensor::Constant(log(2.0f)) * out); }}, - {"log2", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / (in[0] * Tensor::Constant(log(2.0f)))); }}, - {"sqrt", [](ArgumentManager& in, const Tensor& out, const 
Tensor& grad, NodeGrads& grads) { grads.Add(grad / (Tensor::Constant(2.0f) * out)); }}, - {"rsqrt", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(-grad / (Tensor::Constant(2.0f) * in[0] * out)); }}, - {"floor", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f)); }}, - {"ceil", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f)); }}, - {"round", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f)); }}, - {"frac", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad); }}, - {"atan2", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[1] / (in[0] * in[0] + in[1] * in[1]), -grad * in[0] / (in[0] * in[0] + in[1] * in[1])); }}, - {"lerp", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[2], grad * (Tensor::Constant(1.0f) - in[2]), grad * (in[0] - in[1])); }}, - {"max", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::select(in[0] > in[1], grad, Tensor::Constant(0.0f)), Tensor::select(in[0] < in[1], grad, Tensor::Constant(0.0f))); }}, - {"min", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::select(in[0] < in[1], grad, Tensor::Constant(0.0f)), Tensor::select(in[0] > in[1], grad, Tensor::Constant(0.0f))); }}, - {"pow", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[1] * Tensor::pow(in[0], in[1] - Tensor::Constant(1.0f)), grad * Tensor::log(in[0]) * out); }}, - {"tanh", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * (Tensor::Constant(1.0f) - out * out)); }}, - {"clamp", 
[](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //clamp = min(max(x, min), max) - Tensor& dc_dx = Tensor::select((in[0] < in[1]) || (in[0] > in[2]), Tensor::Constant(0.0f), grad); - Tensor& dc_dmin = Tensor::select(in[0] < in[1], grad, Tensor::Constant(0.0f)); - Tensor& dc_dmax = Tensor::select(in[0] > in[2], grad, Tensor::Constant(0.0f)); - grads.Add(dc_dx, dc_dmin, dc_dmax); - }}, - {"ternary", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f), Tensor::select(in[0], grad, Tensor::Constant(0.0f)), Tensor::select(in[0], Tensor::Constant(0.0f), grad)); }}, - {"lerp", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[2], grad * (Tensor::Constant(1.0f) - in[2]), grad * (in[0] - in[1])); }}, - {"step", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f), Tensor::Constant(0.0f)); }}, - {"modf", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad, Tensor::Constant(0.0f)); }}, - {"fma", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(in[1] * grad, in[0] * grad, grad); }}, - - //matrix operations - {"matmul", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Matmul(grad, Tensor::Transpose(in[1])), Tensor::Matmul(Tensor::Transpose(in[0]), grad)); - }}, - {"transpose", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Transpose(grad, out.axis(1), out.axis(0))); - }}, - {"dot", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - Tensor& unsq_grad = Tensor::Unsqueeze(grad, out.axis()); - grads.Add(unsq_grad * in[1], unsq_grad * in[0]); - }}, - {"unsqueeze", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& 
grads) { - grads.Add(Tensor::Squeeze(grad, out.axis())); - }}, - {"squeeze", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Unsqueeze(grad, out.axis())); - }}, - {"dim_sum", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Unsqueeze(grad, out.axis())); - }}, - {"dim_max", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - auto& out_unsq = Tensor::Unsqueeze(out, out.axis()); - auto& grad_unsq = Tensor::Unsqueeze(grad, out.axis()); - grads.Add(Tensor::select(in[0] == out_unsq, grad_unsq, Tensor::Constant(0.0f))); - }}, - {"dim_min", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - auto& out_unsq = Tensor::Unsqueeze(out, out.axis()); - auto& grad_unsq = Tensor::Unsqueeze(grad, out.axis()); - grads.Add(Tensor::select(in[0] == out_unsq, grad_unsq, Tensor::Constant(0.0f))); - }}, - {"dim_prefix_sum", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //b_i = a_0 + ... + a_i - //db_i/da_j = 1 if i >= j, 0 otherwise - //dL/da_j = sum_i dL/db_i * db_i/da_j - //dL/da_j = sum_i dL/db_i * (i >= j) - //g_i == dL/db_i - //dL/da_j = g_j + g_{j+1} + ... + g_n = g_n + g_{n-1} + ... + g_j - //c_i == g_{n-i} - //dL/da_j = c_0 + c_1 + ... 
+ c_j = prefix_sum(c)_j - grads.Add(Tensor::PrefixSum(Tensor::Reverse(grad, out.axis()), out.axis())); - }}, - {"dim_reverse", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Reverse(grad, out.axis())); - }}, - {"reshape", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - const Tensor* memory_input = in.GetTensor(ArgType::Memory); - grads.Add(ArgType::Memory, 0, Tensor::Reshape(grad, memory_input->GetShape())); - }}, - {"assert", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - const Tensor* memory_input = in.GetTensor(ArgType::Memory); - grads.Add(ArgType::Memory, 0, Tensor::Assert(grad, memory_input->GetShape(), memory_input->GetFormat())); - }}, - //memory operations - {"load", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //derivative of load is scatter gradient to the load memory addresses - int index_count = in.Count(ArgType::Index); - - Tensors tensor_indices = in.GetTensorVector(ArgType::Index); - const Tensor& curGrad = *grads.GetGrad(ArgType::Memory, 0); - const Tensor& is_out_of_bounds = *IsOutOfBounds(in.GetTensor(ArgType::Memory), tensor_indices); - const Tensor& grad_out_of_bounds = Tensor::select(is_out_of_bounds, Tensor::Constant(0.0f), grad); - Tensor::ScatterAdd(curGrad, grad_out_of_bounds, tensor_indices, out.node_->indexing_mode_); - }}, - {"store", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //derivative of store is load gradient at the store memory addresses - const Tensor* memory_input = in.GetTensor(ArgType::Memory); - int index_count = in.Count(ArgType::Index); - - Tensors tensor_indices = in.GetTensorVector(ArgType::Index); - const Tensor& memory_grad = *grads.GetGrad(ArgType::Memory, 0); - grads.Add(ArgType::Input, 0, Tensor::Load(memory_grad, tensor_indices, out.node_->indexing_mode_)); - }}, - {"InterlockedAdd", [](ArgumentManager& 
in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //derivative of scatter_add is load gradient at the scatter memory addresses - const Tensor* memory_input = in.GetTensor(ArgType::Memory); - int index_count = in.Count(ArgType::Index); - - Tensors tensor_indices = in.GetTensorVector(ArgType::Index); - const Tensor& memory_grad = *grads.GetGrad(ArgType::Memory, 0); - grads.Add(ArgType::Input, 0, Tensor::Load(memory_grad, tensor_indices, out.node_->indexing_mode_)); - }}, - {"set", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //derivative of set is the gradient of the setted value to the input - const Tensor& memory_grad = *grads.GetGrad(ArgType::Memory, 0); - grads.Add(ArgType::Input, 0, memory_grad); - }}, - {"detached_grad", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - }}, - {"passthrough_grad", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(grad); - }}, -}; - -VJPGradientFunction GetVJPForOperation(string name) { - if (!gradient_functions.contains(name)) { - throw std::runtime_error("Cannot compute VJP for operation " + name); - } - return gradient_functions[name]; -} - -void RegisterVJP(string name, VJPGradientFunction vjp) { - if (gradient_functions.contains(name)) { - throw std::runtime_error("VJP for operation " + name + " already registered"); - } - gradient_functions[name] = vjp; -} - -bool HasDerivativeImplemented(string name) { - return gradient_functions.contains(name); -} - -Tensor* ComputeReduction(const Tensor* array, int axis, - std::function reduction_op, string debug_name = "", - uint initial = 0, - std::function element_op = nullptr) { - // Get shape of the array - Tensors shape = array->GetShape(); - - axis = GetAxis((int)shape.size(), axis); - - // Get the number of dimensions - int dims = (int)shape.size(); - - Tensors sum_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - 
continue; - } - sum_shape.push_back(shape[i]); - } - - // get indices for all dimensions but the last - Tensors indices = Tensors(); - for (int i = 0; i < dims - 1; i++) { - indices.push_back(&Tensor::Index(sum_shape, i)); - } - - // if no dimensions, then add constant 1 - if (sum_shape.empty()) { - sum_shape.push_back(&Tensor::Constant(1)); - } - - Tensors load_index = Tensors(); - for (int id = 0, d = 0; d < dims; d++) { - if (d == axis) { - load_index.push_back(&Tensor::Constant(sum_shape, 0)); - } else { - load_index.push_back(indices[id++]); - } - } - - // start with the first value - Tensor* reduced = &Tensor::Constant(sum_shape, initial, array->node_->format); - reduced->SetDebugName(debug_name); - - // create a loop over the last dimension starting from the second value - Tensor::Loop(Tensor::Constant(0), *shape[axis], Tensor::Constant(1), - [&](const Tensor& i) { - load_index[axis] = &i; - - // load the value - Tensor* value = &Tensor::Load(*array, load_index, IndexingMode::Unsafe); - - if (element_op != nullptr) { - value = element_op(value); - } - - reduced->Set(*reduction_op(reduced, value)); - }); - - return reduced; -} - -Tensor* ComputeScan(const Tensor* array, int axis, std::function scan_op, string debug_name = "", uint initial = 0) { - // Get shape of the array - Tensors shape = array->GetShape(); - - Tensor* scan_result = &Tensor::Memory(shape, array->node_->format); - - axis = GetAxis((int)shape.size(), axis); - - // Get the number of dimensions - int dims = (int)shape.size(); - - Tensors sum_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - continue; - } - sum_shape.push_back(shape[i]); - } - - // get indices for all dimensions but the last - Tensors indices = Tensors(); - for (int i = 0; i < dims - 1; i++) { - indices.push_back(&Tensor::Index(sum_shape, i)); - } - - // if no dimensions, then add constant 1 - if (sum_shape.empty()) { - sum_shape.push_back(&Tensor::Constant(1)); - } - - Tensors load_index = Tensors(); - 
for (int id = 0, d = 0; d < dims; d++) { - if (d == axis) { - load_index.push_back(&Tensor::Constant(sum_shape, 0)); - } else { - load_index.push_back(indices[id++]); - } - } - - // start with the first value - Tensor* reduced = &Tensor::Constant(sum_shape, initial, array->node_->format); - reduced->SetDebugName(debug_name); - - // create a loop over the last dimension starting from the second value - Tensor::Loop(Tensor::Constant(0), *shape[axis], Tensor::Constant(1), - [&](const Tensor& i) { - load_index[axis] = &i; - // load the value - Tensor* value = &Tensor::Load(*array, load_index, IndexingMode::Unsafe); - reduced->Set(*scan_op(reduced, value)); - Tensor::Store(*scan_result, *reduced, load_index, IndexingMode::Unsafe); - }); - - return scan_result; -} - -Tensor* ComputeSum(const Tensor* array, int axis) { - return ComputeReduction(array, axis, [](Tensor* a, Tensor* b) { - return &(*a + *b); }, "sum"); -} - -Tensor* ComputeNorm(const Tensor* array, int axis) { - return &Tensor::sqrt(Tensor::Sum(*array * *array, axis)); -} - -Tensor* ComputeMean(const Tensor* array, int axis) { - return &(Tensor::Sum(*array, axis) / Tensor::tofloat(*array->GetShape()[axis])); -} - -uint GetInitialMax(TFType type) { - if (type == TFType::Float) { - float init = -FLT_MAX; - return *(uint*)&init; - } - else if (type == TFType::Int) { - int init = INT_MIN; - return *(uint*)&init; - } - return 0; -} - -uint GetInitialMin(TFType type) { - if (type == TFType::Float) { - float init = FLT_MAX; - return *(uint*)&init; - } - else if (type == TFType::Int) { - int init = INT_MAX; - return *(uint*)&init; - } - return 0; -} - -Tensor* ComputeMax(const Tensor* array, int axis) { - uint initial = 0; - if (array->node_->format == TFTypeFloat32) { - float init = -FLT_MAX; - initial = *(uint*)&init; - } - else if (array->node_->format == TFTypeInt32) { - int init = INT_MIN; - initial = *(uint*)&init; - } - return ComputeReduction( - array, axis, [](Tensor* a, Tensor* b) { return &Tensor::max(*a, 
*b); }, - "max", initial); -} - -Tensor* ComputeMin(const Tensor* array, int axis) { - uint initial = UINT_MAX; - if (array->node_->format == TFTypeFloat32) { - float init = FLT_MAX; - initial = *(uint*)&init; - } - else if (array->node_->format == TFTypeInt32) { - int init = INT_MAX; - initial = *(uint*)&init; - } - return ComputeReduction( - array, axis, [](Tensor* a, Tensor* b) { return &Tensor::min(*a, *b); }, - "min", initial); -} - -Tensor* ComputeProduct(const Tensor* array, int axis) { - uint initial = 1; - if (array->node_->format == TFTypeInt32) { - float init = 1.0f; - initial = *(uint*)&init; - } - return ComputeReduction( - array, axis, [](Tensor* a, Tensor* b) { return &(*a * *b); }, "prod", initial); -} - -Tensor* ComputeAny(const Tensor* array, int axis) { - return ComputeReduction(array, axis, [](Tensor* a, Tensor* b) { return &(*a || *b); }, "any", 0); -} - -Tensor* ComputeAll(const Tensor* array, int axis) { - return ComputeReduction( - array, axis, [](Tensor* a, Tensor* b) { return &(*a && *b); }, "all", ~0); -} - -Tensor* ComputePrefixSum(const Tensor* array, int axis) { - return ComputeScan(array, axis, [](Tensor* a, Tensor* b) { return &(*a + *b); }, "prefix_sum"); -} - -Tensor* Transpose(const Tensor* array, map permutation) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int old_dim = shapeinfo.dim; - Tensors perm_shape = Tensors(); - int permuted_dim = (int)permutation.size(); - - shapeinfo.ExpandDimensionsTo(permuted_dim); - Tensors shape = shapeinfo.GetTensors(); - - for (int i = 0; i < permuted_dim; i++) { - perm_shape.push_back(shape[permutation[i]]); - } - - //create indices - Tensors indices = Tensors(); - for (int i = 0; i < permuted_dim; i++) { - indices.push_back(&Tensor::Index(perm_shape, i)); - } - //permute indices to load the values - Tensors perm_indices = Tensors(old_dim, nullptr); - for (int i = 0; i < permuted_dim; i++) { - int old = permutation[i]; - if(old < old_dim) { - perm_indices[old] = indices[i]; - } - } - //if 
any nullptr, then put a constant 0 - for (int i = 0; i < old_dim; i++) { - if(perm_indices[i] == nullptr) { - perm_indices[i] = &Tensor::Constant(0); - } - } - - Tensor& loaded = Tensor::Load(*array, perm_indices, IndexingMode::Unsafe); - loaded.SetShape(perm_shape); //in case perm indices has no shape info (0 dim unsqueeze) - loaded.SetDebugName("transposed"); - return &loaded; -} - -Tensor* ReverseDim(const Tensor* array, int axis) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - Tensors indices = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - indices.push_back(&(*shape[i] - Tensor::Constant(1) - Tensor::Index(shape, i))); - } else { - indices.push_back(&Tensor::Index(shape, i)); - } - } - Tensor& loaded = Tensor::Load(*array, indices, IndexingMode::Unsafe); - loaded.SetDebugName("reversed"); - return &loaded; -} - -Tensor* SplitDim(const Tensor* array, const Tensor* splitted, int axis, int split_size) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors new_shape = splitted->GetShape(); - Tensors indices = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - Tensor* index1 = &Tensor::Index(new_shape, i); - Tensor* index2 = &Tensor::Index(new_shape, i + 1); - //merged index - indices.push_back(&(*index2 * (*new_shape[i]) + *index1)); - } else if(i < axis) { - indices.push_back(&Tensor::Index(new_shape, i)); - } else { - indices.push_back(&Tensor::Index(new_shape, i + 1)); - } - } - Tensor* loaded = ConstantOutOfBounds(array, indices, 0); - loaded->SetDebugName("split"); - return loaded; -} - -Tensor* MergeDim(const Tensor* array, const Tensor* merged, int axis) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - axis = GetAxis(dims, axis); - Tensors shape = shapeinfo.GetTensors(); - Tensors new_shape = merged->GetShape(); - Tensors indices = Tensors(); - for (int i = 0; i < dims-1; i++) { - if (i == 
axis) { - Tensor* merged_index = &Tensor::Index(new_shape, i); - //get split index - indices.push_back(&(*merged_index % *shape[axis])); - indices.push_back(&(*merged_index / *shape[axis])); - } else { - indices.push_back(&Tensor::Index(new_shape, i)); - } - } - Tensor* loaded = ConstantOutOfBounds(array, indices, 0); - loaded->SetDebugName("merged"); - return loaded; -} - -Tensor* ComputeDot(const Tensor* a, const Tensor* b, int axis) { - Tensors shape_a = a->GetShape(); - Tensors shape_b = b->GetShape(); - axis = GetAxis((int)shape_a.size(), axis); - return ComputeSum(&(*a * *b), axis); -} - -//compute the matrix multiplication of two last dimensions -//takes two tensors [T1, T2, ..., Tn, M, N] and [Tm, .., Tn, N, K] and returns [T1, T2, ..., Tm, M, K] -Tensor* ComputeMatMul(const Tensor* a, const Tensor* b) { - ShapeInfo shape_a = a->GetShapeInfo(); - ShapeInfo shape_b = b->GetShapeInfo(); - - if (shape_a.dim < 2 && shape_b.dim < 2) { - throw std::runtime_error("Matrix multiplication requires at least one 2D tensor"); - } - - if(shape_a.dim < 2) { - shape_a.ExpandDimensionsTo(2); - } - if(shape_b.dim < 2) { - shape_b.ExpandDimensionsTo(2); - } - - Tensors shape_a_tensors = shape_a.GetTensors(); - Tensors shape_b_tensors = shape_b.GetTensors(); - - //get shape of the result - Tensors shape_c = Tensors(); - int dim_a = shape_a.dim; - int dim_b = shape_b.dim; - int max_dim = 0; - Tensors max_shape = Tensors(); - //get the shape with most dimensions - if (dim_a < dim_b) { - max_dim = dim_b; - max_shape = shape_b_tensors; - } else { - max_dim = dim_a; - max_shape = shape_a_tensors; - } - - shape_c.push_back(shape_b_tensors[0]); - shape_c.push_back(shape_a_tensors[1]); - for (int i = 2; i < max_dim; i++) { - shape_c.push_back(max_shape[i]); - } - ShapeDimCompareResult result = CompareShapeDim(shape_a_tensors[0]->node_, shape_b_tensors[1]->node_); - if (!result.compatible) { - throw std::runtime_error("Inner dimensions of the matrices must match"); - } - - const 
Tensor* sum_shape = result.broadcast_dim->GetTensor(); - - // get indices for c elements - Tensors indices_c = Tensors(); - for (int i = 0; i < max_dim; i++) { - indices_c.push_back(&Tensor::Index(shape_c, i)); - } - - // start with 0 - Tensor* c = &Tensor::Constant(shape_c, 0, a->node_->format); - c->SetDebugName("matmul"); - - // loop over k and compute += A t1t2..tN ik * B t1t2..tN kj - Tensor::Loop(Tensor::Constant(0), *sum_shape, Tensor::Constant(1), - [&](const Tensor& k) { - - // get indices for a elements - Tensors indices_a = Tensors(); - - indices_a.push_back(&k); - indices_a.push_back(indices_c[1]); - for (int i = 2; i < dim_a; i++) { - indices_a.push_back(indices_c[i]); - } - - // get indices for b elements - Tensors indices_b = Tensors(); - - indices_b.push_back(indices_c[0]); - indices_b.push_back(&k); - for (int i = 2; i < dim_b; i++) { - indices_b.push_back(indices_c[i]); - } - - Tensor& a_val = Tensor::Load(*a, indices_a, IndexingMode::Unsafe); a_val.node_->flags.set(NodeProp::NoLoadFusion); // disable load fusion for now - Tensor& b_val = Tensor::Load(*b, indices_b, IndexingMode::Unsafe); a_val.node_->flags.set(NodeProp::NoLoadFusion); - - Tensor* prod = &(a_val * b_val); - - c->Set(*c + *prod); - }); - - return c; -} - -map implementation_functions = -{ - {"dim_sum", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeSum(inputs[0],axes[0])); }}, - {"dim_norm", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeNorm(inputs[0],axes[0])); }}, - {"dim_max", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeMax(inputs[0],axes[0])); }}, - {"dim_min", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeMin(inputs[0],axes[0])); }}, - {"dim_mean", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeMean(inputs[0],axes[0])); }}, - 
{"dim_product", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeProduct(inputs[0],axes[0])); }}, - {"dim_any", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeAny(inputs[0],axes[0])); }}, - {"dim_all", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeAll(inputs[0],axes[0])); }}, - {"dim_prefix_sum", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputePrefixSum(inputs[0],axes[0])); }}, - {"transpose", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - //get the permutation - int dim = (int)inputs[0]->GetDimension(); - dim = std::max(dim, std::max(axes[0], axes[1]) + 1); - map permutation; - for (int i = 0; i < dim; i++) { - if(i == axes[0]) { - permutation[i] = axes[1]; - } else if(i == axes[1]) { - permutation[i] = axes[0]; - } else { - permutation[i] = i; - } - } - outputs.push_back(Transpose(inputs[0], permutation)); - }}, - {"dot", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeDot(inputs[0], inputs[1], axes[0])); }}, - {"matmul", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeMatMul(inputs[0], inputs[1])); }}, - {"unsqueeze", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - map permutation; - int dim = (int)inputs[0]->GetDimension()+1; - dim = std::max(dim, axes[0] + 1); - for(int i = 0; i < dim; i++) { - if(i == axes[0]) { - permutation[i] = dim-1; - } else if (i < axes[0]) { - permutation[i] = i; - } else { - permutation[i] = i - 1; - } - } - outputs.push_back(Transpose(inputs[0], permutation)); - }}, - {"squeeze", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - map permutation; - int dim = (int)inputs[0]->GetDimension() - 1; - for(int i = 0; i < dim; i++) { - if(i < axes[0]) { - permutation[i] = i; - } else { - 
permutation[i] = i + 1; - } - } - outputs.push_back(Transpose(inputs[0], permutation)); - }}, - {"dim_reverse", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ReverseDim(inputs[0], axes[0])); }}, - {"dim_split", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(SplitDim(inputs[0], tensor, axes[0], axes[1])); }}, - {"dim_merge", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(MergeDim(inputs[0], tensor, axes[0])); }}, - {"dim_repeat", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - const Tensor* input_tensor = inputs[0]; - Tensors shape = input_tensor->GetShape(); - Tensors new_shape = tensor->GetShape(); - Tensors indices = Tensors(); - for (int i = 0; i < (int)new_shape.size(); i++) { - indices.push_back(&Tensor::Index(new_shape, i)); - } - Tensor* loaded = &Tensor::Load(*input_tensor, indices, IndexingMode::Repeat); - loaded->SetDebugName("repeated"); - outputs.push_back(loaded); - }}, - {"smoothstep", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - const Tensor& x = *inputs[2]; - const Tensor& edge0 = *inputs[0]; - const Tensor& edge1 = *inputs[1]; - Tensor& x1 = (x - edge0) / (edge1 - edge0); - Tensor& t = Tensor::clamp(x1, Tensor::Constant(0.0f), Tensor::Constant(1.0f)); - Tensor& result = t * t * (Tensor::Constant(3.0f) - Tensor::Constant(2.0f) * t); - outputs.push_back(&result); - }}, -}; - -ImplementationFunction GetImplementationForOperation(string name) { - if (!implementation_functions.contains(name)) { - throw std::runtime_error("Cannot compute implementation for operation " + name); - } - return implementation_functions[name]; -} - -void RegisterImplementation(string name, ImplementationFunction impl) { - if (implementation_functions.contains(name)) { - throw std::runtime_error("Implementation for operation " + name + " already exists"); - } - implementation_functions[name] = impl; -} - - 
-map algorithm_vjps = {}; - -AlgorithmVJPGradientFunction GetAlgorithmVJPForOperation(string name) { - if (!algorithm_vjps.contains(name)) { - throw std::runtime_error("Cannot compute VJP for operation " + name); - } - return algorithm_vjps[name]; -} - -void RegisterAlgorithmVJP(string name, AlgorithmVJPGradientFunction vjp) { - if (algorithm_vjps.contains(name)) { - throw std::runtime_error("VJP for operation " + name + " already registered"); - } - algorithm_vjps[name] = vjp; -} - -VJPGradientFunction CreateAlgorithmVJP(const string& name) { - VJPGradientFunction vjp = [name](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - auto inputs = in.GetTensors(ArgType::Input); - AlgorithmVJPGradientFunction impl = GetAlgorithmVJPForOperation(name); - Tensors grad_tensors = impl(inputs, &grad, &out); - for (int i = 0; i < (int)grad_tensors.size(); i++) { - grads.Add(ArgType::Input, i, *const_cast(grad_tensors[i])); - } - }; - return vjp; -} - -void RegisterAlgorithmicPrimitive(const string& name, vector overloads, ImplementationFunction impl, AlgorithmVJPGradientFunction vjp) { - Operation* newop = new Operation(name, overloads, 0, "", {OpProp::Custom, OpProp::Algorithm}); - RegisterNewOperation(newop); - RegisterImplementation(name, impl); - RegisterAlgorithmVJP(name, vjp); - RegisterVJP(name, CreateAlgorithmVJP(name)); -} - -} // namespace TensorFrost - - diff --git a/TensorFrost/Compiler/Implementations.h b/TensorFrost/Compiler/Implementations.h deleted file mode 100644 index 9f221f7e..00000000 --- a/TensorFrost/Compiler/Implementations.h +++ /dev/null @@ -1,112 +0,0 @@ -#pragma once - -#include "Backend/TensorMemory.h" -#include "Tensor/Tensor.h" - -namespace TensorFrost { - -const Tensor& ReduceGradientToShape(const Tensor& gradient, const Tensor& target); -int GetGradAxis(const Tensor& out, const Tensor& grad); - -class NodeGrads -{ - bool stop_fusion = true; - unordered_map stored_gradients; - unordered_map arguments; - unordered_map 
argument_inputs; -public: - //get element at index - const Tensor& operator[](ArgID id) { - return *stored_gradients[arguments[id]]; - } - - bool Contains(ArgID id) { - return stored_gradients.contains(arguments[id]); - } - - bool Contains(ArgType type, int index = 0) { - return Contains(ArgID(type, index)); - } - - NodeGrads(Node* node, map input_grads) { - try { - for(auto& [id, input] : node->args.Inputs()) { - if (id.first == ArgType::Index || id.first == ArgType::Shape) { - continue; - } - argument_inputs[id] = input->GetTensor(); - arguments[id] = input; - if(input_grads.contains(input)) { - stored_gradients[input] = &ReduceGradientToShape(*input_grads[input], *input->GetTensor()); - } - } - } catch (const std::exception& e) { - throw std::runtime_error("Error in gradient initialization: " + string(e.what())); - } - } - - void Add(ArgType type, int index, const Tensor& tensor) { - ArgID id = ArgID(type, index); - const Tensor* target = argument_inputs[id]; - try { - Tensor& new_tensor = const_cast(ReduceGradientToShape(tensor, *target)); - if(Contains(type, index)) { - auto& old_tensor = *stored_gradients[arguments[id]]; - if(stop_fusion) old_tensor.StopFusion(); - auto& loaded = old_tensor; //TODO: temporary way to restrict fusion, remove after implementing split - stored_gradients[arguments[id]] = &(loaded + new_tensor); - } else { - stored_gradients[arguments[id]] = &new_tensor; - } - } catch (const std::exception& e) { - throw std::runtime_error("Error in gradient addition: " + string(e.what())); - } - } - - const Tensor* GetGrad(ArgID id) { - if(Contains(id)) { - return stored_gradients[arguments[id]]; - } else { - IR* cur_ir = Tensor::GetEvaluationContext(); - const Tensor* input = argument_inputs[id]; - Tensor* zero_grad = nullptr; - cur_ir->ExecuteExpressionAfter(input->node_, [&]() { - zero_grad = &Tensor::Constant(argument_inputs[id]->GetShape(), 0.0f); - stored_gradients[arguments[id]] = zero_grad; - }); - return zero_grad; - } - } - - const 
Tensor* GetGrad(ArgType type, int index) { - return GetGrad(ArgID(type, index)); - } - - //add gradients to inputs - template - void Add(const Tensor& arg, Args&... args) { - //by default these are ArgType::Input - vector inputs = vector({ &arg, &args... }); - for (int i = 0; i < inputs.size(); i++) { - Add(ArgType::Input, i, *inputs[i]); - } - } -}; - - -typedef function VJPGradientFunction; -typedef function inputs, const Tensor* gradient, const Tensor* tensor)> AlgorithmVJPGradientFunction; - -VJPGradientFunction GetVJPForOperation(string name); -void RegisterVJP(string name, VJPGradientFunction vjp); -bool HasDerivativeImplemented(string name); -//TODO JVPGradientFunction for forward mode autodiff - -typedef function inputs, const Tensor* tensor, vector axes)> ImplementationFunction; - -ImplementationFunction GetImplementationForOperation(string name); -void RegisterImplementation(string name, ImplementationFunction impl); - -void RegisterAlgorithmicPrimitive(const string& name, vector overloads, ImplementationFunction impl, AlgorithmVJPGradientFunction vjp); - -} diff --git a/TensorFrost/Compiler/KernelGen.cpp b/TensorFrost/Compiler/KernelGen.cpp deleted file mode 100644 index 4c0cd3e1..00000000 --- a/TensorFrost/Compiler/KernelGen.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include "Compiler/KernelGen.h" - -namespace TensorFrost { - -Program* GenerateProgram(IR* ir) -{ - ir->CompileIR(); - - auto* program = new Program(ir); - - vector kernels = ir->GetNodesOfType("kernel"); - - for (auto kernel : kernels) - { - // get the kernel type - map variables; - map read_write; - set group_memory; - NodeArguments shape = kernel->args.GetArguments(ArgType::Shape); - size_t variable_index = 0; - - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if(node->name == "group_memory") { - group_memory.insert(*node); - continue; - } - if (node->op->HasAllTypes(OpProp::MemoryOp)) { - //if the memory is inside of this kernel - skip node - if 
(node->flags.has(NodeProp::LocalMemoryOp)) { - continue; - } - - // get the memory node - const Tensor* memory = node->args.GetTensor(ArgType::Memory); - - if(node->op->HasAllTypes(OpProp::Modifier)) { - read_write[memory->node_] |= true; - } else { - read_write[memory->node_] |= false; - } - } - - // get all input arguments - for (auto [id, from] : node->args.Inputs()) { - if (id.first == ArgType::Input) - { - bool from_outside_kernel = !from->HasParent(kernel); - if (from_outside_kernel && !variables.contains(from)) { - variables[from] = variable_index++; - } - } - } - } - - map read_write_memory; - map read_only_memory; - size_t read_write_index = 0; - size_t read_only_index = 0; - for(auto [node, rw] : read_write) { - if(rw) { - read_write_memory[node] = read_write_index++; - } else { - read_only_memory[node] = read_only_index++; - } - } - - // add the kernel to the program - program->AddKernel(kernel, variables, read_write_memory, read_only_memory, group_memory, shape); - } - - return program; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/KernelGen.h b/TensorFrost/Compiler/KernelGen.h deleted file mode 100644 index 30b3c009..00000000 --- a/TensorFrost/Compiler/KernelGen.h +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include "Backend/TensorMemory.h" -#include "Tensor/Tensor.h" -#include "Compiler/Implementations.h" - -namespace TensorFrost { - -uint GetInitialMax(TFType type); -uint GetInitialMin(TFType type); - -class Kernel { - public: - Node* root; - map variables; - map read_write_memory; - map read_only_memory; - set group_memory; - NodeArguments shape; - - size_t kernel_id_; - string kernel_name_; - string full_generated_code_; - string generated_header_; - string generated_bindings_; - string generated_main_; - - vector var_names; - vector var_types; - - map GetMemoryBindings() { - map result; - for (auto& mem : read_write_memory) { - result[mem.first] = 
mem.second; - } - for (auto& mem : read_only_memory) { - result[mem.first] = mem.second + read_write_memory.size(); - } - return result; - } -}; - -class Program { - public: - IR* ir_; - vector kernels_; - function unload_callback; - string generated_code_; - string main_function_; - string program_name = "TensorProgram"; - - float last_execution_time = 0.0f; - float shader_compile_time = 0.0f; - float host_compile_time = 0.0f; - - function execute_callback; - - explicit Program(IR* ir) : ir_(ir) {} - - void AddKernel(Node* kernel_node, map variables, map read_write, map read_only, set group_memory, - NodeArguments shape) - { - kernels_.push_back( - {kernel_node, std::move(variables), std::move(read_write), std::move(read_only), std::move(group_memory), std::move(shape)}); - } -}; - -Program* GenerateProgram(IR* ir); - -bool isConstantAndEqualTo(const Tensor* tensor, float value); -bool isConstant(const Tensor* tensor); -Tensor* ApplyMultiOP(const Tensor* a, const Tensor* b, std::function opF32, std::function opI32, std::function opU32); -Tensor* ApplyUnaryOP(const Tensor* a, std::function opF32, std::function opI32, std::function opU32); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Operations.cpp b/TensorFrost/Compiler/Operations.cpp deleted file mode 100644 index 34091587..00000000 --- a/TensorFrost/Compiler/Operations.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "Operations.h" -#include - -namespace TensorFrost { - -unordered_map type_names = { - {TFType::None, "void"}, {TFType::Bool, "bool"}, {TFType::Float, "float"}, - {TFType::Uint, "uint"}, {TFType::Int, "int"}, -}; - -std::unordered_map DataTypeNames = { - {TFType::Float, "Float"}, {TFType::Uint, "Uint"}, - {TFType::Int, "Int"}, {TFType::Bool, "Bool"}, - {TFType::None, "None"}, -}; - -std::map DataFormatNames = { - {TFTypeFloat32, "TFTypeFloat32"}, {TFTypeInt32, "TFTypeInt32"}, - {TFTypeUint32, "TFTypeUint32"}, {TFTypeBool32, "TFTypeBool32"}, - {TFTypeNone, 
"TFTypeNone"}, -}; - -const vector operations = { - //Scope operations - Operation("host", {""}, 0, "", {OpProp::Static, OpProp::Special, OpProp::HostOnly, OpProp::Nondiff, OpProp::HasChildren}), - Operation("kernel", {""}, 0, "", {OpProp::Static, OpProp::Special, OpProp::HostOnly, OpProp::Nondiff, OpProp::HasChildren}), - - //Control operations - Operation("loop", {"iii_i"}, 100, "", {OpProp::Static, OpProp::Special, OpProp::Nondiff, OpProp::HasChildren}), - Operation("if", {"b_"}, 100, "", {OpProp::Static, OpProp::Special, OpProp::Nondiff, OpProp::HasChildren}), - Operation("break", {""}, 0, "break", {OpProp::Static, OpProp::Nondiff}, OpClass::Keyword), - Operation("continue", {""}, 0, "continue", {OpProp::Static, OpProp::Nondiff}, OpClass::Keyword), - Operation("discard", {""}, 0, "discard", {OpProp::Static, OpProp::Nondiff}, OpClass::Keyword), //discard current thread - Operation("group_barrier", {""}, 256, "", {OpProp::Static, OpProp::KernelOnly}), - - //Allocation operations - Operation("memory", {"_f", "_i", "_u", "_b"}, 0, "", {OpProp::Memory, OpProp::Special, OpProp::HostOnly, OpProp::Nondiff}), - Operation("reshape", {"_f", "_i", "_u", "_b"}, 0, "", {OpProp::Memory, OpProp::Special, OpProp::HostOnly, OpProp::MemoryReuse}), - Operation("assert", {"_f", "_i", "_u", "_b"}, 0, "assert_tensor", {OpProp::Memory, OpProp::Special, OpProp::HostOnly, OpProp::MemoryReuse}), - Operation("input_shape", {"_i"}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff}), - Operation("deallocate", {""}, 0, "", {OpProp::Memory, OpProp::Special, OpProp::HostOnly, OpProp::Nondiff}), - Operation("group_memory", {"_f", "_i", "_u", "_b"}, 0, "", {OpProp::Memory, OpProp::LocalMemory, OpProp::Special, OpProp::KernelOnly}), - Operation("local_memory", {"_f", "_i", "_u", "_b"}, 0, "", {OpProp::Memory, OpProp::LocalMemory, OpProp::Special, OpProp::KernelOnly}), - - Operation("region_begin", {""}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, 
OpProp::Nondiff, OpProp::Debug}), - Operation("region_end", {""}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff, OpProp::Debug}), - Operation("print_value", {"f_", "i_", "u_", "b_"}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff, OpProp::Debug}), - Operation("assert_value", {"b_"}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff, OpProp::Debug}), - - //Algorithms - //Reduction - Operation("dim_sum", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_norm", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm}), - Operation("dim_max", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_min", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_mean", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_prod", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_any", {"u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm, OpProp::Nondiff, OpProp::Reduction}), - Operation("dim_all", {"u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm, OpProp::Nondiff, OpProp::Reduction}), - //Scan - Operation("dim_prefix_sum", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Scan}), - //Matrix - Operation("transpose", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - Operation("dot", {"ff_f"}, 0, "", {OpProp::Algorithm}), // dot product of the last dimensions - Operation("matmul", {"ff_f"}, 0, "", {OpProp::Algorithm}), // matrix multiplication of the last dimensions - - //Other - Operation("dim_reverse", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - Operation("dim_repeat", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - // Operation("dim_concat", {"ff_f", "uu_u", "ii_i", "bb_b"}, 0, "", {OpProperty::Algorithm}), - Operation("dim_split", {"f_f", "u_u", "i_i", "b_b"}, 0, "", 
{OpProp::Algorithm}), - Operation("dim_merge", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - // Operation("dim_pad", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProperty::Algorithm}), - Operation("unsqueeze", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - Operation("squeeze", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - - Operation("smoothstep", {"fff_f"}, 10, "", {OpProp::Algorithm}), - - //Autodiff - Operation("backwards_grad", {"ff_f"}, 0, "", {OpProp::Gradient}), - Operation("forward_grad", {"ff_f"}, 0, "", {OpProp::Gradient}), - - // Memory operations - Operation("load", {"_f", "_u", "_i", "_b"}, 128, "", - {OpProp::Load, OpProp::MemoryOp}), - Operation("store", {"f_", "u_", "i_", "b_"}, 128, "", - {OpProp::Store, OpProp::MemoryOp, OpProp::Modifier}), - Operation("set", {"f_", "u_", "i_", "b_"}, 1, "", - {OpProp::Set, OpProp::Modifier}), - Operation("InterlockedAdd", {"u_", "i_", "f_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier}), - Operation("InterlockedMin", {"u_", "i_", "f_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier}), - Operation("InterlockedMax", {"u_", "i_", "f_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier}), - Operation("InterlockedAnd", {"u_", "i_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier, OpProp::Nondiff}), - Operation("InterlockedOr", {"u_", "i_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier, OpProp::Nondiff}), - Operation("InterlockedXor", {"u_", "i_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier, OpProp::Nondiff}), - Operation("InterlockedAdd_Prev", {"u_u", "i_i", "f_f"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier, OpProp::CantSubstitute, OpProp::Nondiff}), - - // Index operations - Operation("dim_id", {"_i"}, 0, "dim", {OpProp::Nondiff}, OpClass::DimensionIndex), - Operation("block_thread_id", {"_i"}, 0, "", {OpProp::Nondiff}, OpClass::DimensionIndex), - 
Operation("block_id", {"_i"}, 0, "", {OpProp::Nondiff}, OpClass::Variable), - - //Compute operations - Operation("copy", {"f_f", "u_u", "i_i", "b_b"}, 1, "", {}, OpClass::Copy), //TODO: make sure no one copies memory objects - Operation("add", {"ff_f", "uu_u", "ii_i"}, 1, "+", {}, OpClass::Operator), - Operation("sub", {"ff_f", "uu_u", "ii_i"}, 1, "-", {}, OpClass::Operator), - Operation("mul", {"ff_f", "uu_u", "ii_i"}, 1, "*", {}, OpClass::Operator), - Operation("div", {"ff_f", "uu_u", "ii_i"}, 2, "/", {}, OpClass::Operator), - Operation("mod", {"ff_f", "uu_u", "ii_i"}, 4, "%", {OpProp::Nondiff}, OpClass::Operator), - Operation("lshift", {"uu_u", "ui_u", "ii_i"}, 1, "<<", {OpProp::Nondiff}, OpClass::Operator), - Operation("rshift", {"uu_u", "ui_u", "ii_i"}, 1, ">>", {OpProp::Nondiff}, OpClass::Operator), - Operation("and", {"uu_u", "ii_i", "bb_b"}, 1, "&", {OpProp::Nondiff}, OpClass::Operator), - Operation("or", {"uu_u", "ii_i", "bb_b"}, 1, "|", {OpProp::Nondiff}, OpClass::Operator), - Operation("xor", {"uu_u", "ii_i", "bb_b"}, 1, "^", {OpProp::Nondiff}, OpClass::Operator), - Operation("eq", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, "==", {OpProp::Nondiff}, OpClass::Operator), - Operation("neq", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, "!=", {OpProp::Nondiff}, OpClass::Operator), - Operation("lt", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, "<", {OpProp::Nondiff}, OpClass::Operator), - Operation("lte", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, "<=", { OpProp::Nondiff}, OpClass::Operator), - Operation("gt", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, ">", {OpProp::Nondiff}, OpClass::Operator), - Operation("gte", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, ">=", {OpProp::Nondiff}, OpClass::Operator), - Operation("notb", {"b_b"}, 1, "!", {OpProp::Nondiff}, OpClass::UnaryOperator), - Operation("not", {"u_u", "i_i"}, 1, "~", {OpProp::Nondiff}, OpClass::UnaryOperator), - Operation("neg", {"f_f", "u_u", "i_i"}, 
1, "-", {}, OpClass::UnaryOperator), - Operation("uint", {"f_u", "u_u", "i_u", "b_u"}, 1, "uint", {OpProp::Nondiff}, OpClass::TypeCast), - Operation("int", {"f_i", "u_i", "i_i", "b_i"}, 1, "int", {OpProp::Nondiff}, OpClass::TypeCast), - Operation("float", {"f_f", "u_f", "i_f", "b_f"}, 1, "float", {OpProp::Nondiff}, OpClass::TypeCast), - Operation("bool", {"f_b", "u_b", "i_b", "b_b"}, 1, "bool", {OpProp::Nondiff}, OpClass::TypeCast), - Operation("asuint", {"f_u", "u_u", "i_u", "b_u"}, 0, "asuint", - {OpProp::Nondiff}, OpClass::TypeReinterpret), - Operation("asint", {"f_i", "u_i", "i_i", "b_i"}, 0, "asint", - {OpProp::Nondiff}, OpClass::TypeReinterpret), - Operation("asfloat", {"f_f", "u_f", "i_f", "b_f"}, 0, "asfloat", - {OpProp::Nondiff}, OpClass::TypeReinterpret), - Operation("asbool", {"f_b", "u_b", "i_b", "b_b"}, 0, "asbool", - {OpProp::Nondiff}, OpClass::TypeReinterpret), - Operation("min", {"ff_f", "uu_u", "ii_i"}, 1), - Operation("max", {"ff_f", "uu_u", "ii_i"}, 1), - Operation("abs", {"f_f", "u_u", "i_i"}, 1), - Operation("sign", {"f_f", "i_i"}, 1), - Operation("ceil", {"f_f"}, 1), - Operation("floor", {"f_f"}, 1), - Operation("round", {"f_f"}, 1), - Operation("frac", {"f_f"}, 1), - Operation("exp", {"f_f"}, 32), - Operation("exp2", {"f_f"}, 16), - Operation("log", {"f_f"}, 32), - Operation("log2", {"f_f"}, 16), - Operation("sqrt", {"f_f"}, 4), - Operation("rsqrt", {"f_f"}, 2), - Operation("rcp", {"f_f"}, 2), - Operation("sin", {"f_f"}, 2), - Operation("cos", {"f_f"}, 2), - Operation("tan", {"f_f"}, 2), - Operation("asin", {"f_f"}, 8), - Operation("acos", {"f_f"}, 8), - Operation("atan", {"f_f"}, 8), - Operation("sinh", {"f_f"}, 8), - Operation("cosh", {"f_f"}, 8), - Operation("tanh", {"f_f"}, 8), - Operation("pcg", {"u_u"}, 32), - Operation("reversebits", {"i_i", "u_u"}, 8), - Operation("pcgf", {"u_f"}, 34, "", {OpProp::Nondiff}), - Operation("pow", {"ff_f"}, 6), - Operation("atan2", {"ff_f"}, 32), - Operation("modf", {"ff_f"}, 2), - Operation("step", 
{"ff_f"}, 2), - Operation("clamp", {"fff_f", "uuu_u", "iii_i"}, 4), - Operation("lerp", {"fff_f"}, 4), - Operation("fma", {"fff_f"}, 1), - Operation("ternary", {"bff_f", "buu_u", "bii_i", "bbb_b"}, 4, "", {}, OpClass::TernaryOperator), - Operation("const", {"_f", "_u", "_i", "_b"}, 0, "", {OpProp::Nondiff}, OpClass::Constant), -}; - -void RegisterNewOperation(unordered_map& operation_map, const Operation* op) { - if (operation_map.contains(op->name_)) { - throw runtime_error("Operation already exists: " + op->name_); - } - operation_map[op->name_] = op; -} - -unordered_map CreateOperationMap() { - unordered_map operation_map; - for (const auto& op : operations) { - RegisterNewOperation(operation_map, &op); - } - return operation_map; -} - -unordered_map operation_map = CreateOperationMap(); - -void RegisterNewOperation(const Operation* op) { - RegisterNewOperation(operation_map, op); -} - -DataTypeList Types(initializer_list elements) { - return DataTypeList(elements); -} - -const Operation* FindOperation(const string& name) { - if (name == "") { - throw runtime_error("Operation name is empty"); - } - - auto it = operation_map.find(name); - if (it != operation_map.end()) { - return it->second; - } - - throw runtime_error("IR Operation not defined: " + name); -} - -string DataTypeToString(TFType type) { return type_names[type]; } - -string RemoveSpaces(string str) { - str.erase(remove(str.begin(), str.end(), ' '), str.end()); - return str; -} - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Operations.h b/TensorFrost/Compiler/Operations.h deleted file mode 100644 index a6fd6f1f..00000000 --- a/TensorFrost/Compiler/Operations.h +++ /dev/null @@ -1,284 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "Utility/Utility.h" - -namespace TensorFrost { - -using namespace std; - -extern "C" { - enum TFType { - Float, - Uint, - Int, - Bool, - None, - }; - - struct TFDataFormat { - 
TFType type; - size_t size; - - bool operator==(const TFDataFormat& other) const { - return type == other.type && size == other.size; - } - - bool operator!=(const TFDataFormat& other) const { - return !(*this == other); - } - - int GetHash() const { - return (int)type << 16 | (int)size; - } - - bool operator<(const TFDataFormat& other) const { - return GetHash() < other.GetHash(); - } - - bool operator>(const TFDataFormat& other) const { - return GetHash() > other.GetHash(); - } - }; - -#define TFTypeNone TFDataFormat{TFType::None, 0} -#define TFTypeBool32 TFDataFormat{TFType::Bool, 32} -#define TFTypeFloat32 TFDataFormat{TFType::Float, 32} -#define TFTypeInt32 TFDataFormat{TFType::Int, 32} -#define TFTypeUint32 TFDataFormat{TFType::Uint, 32} -} - -extern std::unordered_map DataTypeNames; -extern std::map DataFormatNames; -extern std::unordered_map type_names; - -//op can have only one class -enum class OpClass { - Operator, - UnaryOperator, - Function, - Copy, - Keyword, - DimensionIndex, - Variable, - TypeCast, - TypeReinterpret, - Constant, - TernaryOperator, - None, -}; - -//op can have multiple properties -enum class OpProp { - Load, - Store, - Set, - Scatter, - Special, - Memory, - LocalMemory, - CantSubstitute, - MemoryOp, - LocalMemoryOp, - Static, //can not be removed or copied - HostOnly, - KernelOnly, - Composite, - Algorithm, - Custom, - Reduction, - Scan, - Modifier, - MemoryReuse, - Gradient, - Nondiff, - HasChildren, - Debug, - Count, -}; - -using OpProps = FlagSet; - -using DataTypeList = vector; - -DataTypeList Types(initializer_list elements); - -class Operation { -public: - string name_; - float cost_ = 0.0F; - vector, TFType>> overloads_; - string code_; - //vector op_classes; - OpProps props_; - OpClass class_; - size_t default_size = 32; - - Operation() = default; - - Operation(string name, vector overloads, float cost, - string code = "", initializer_list op_props = {}, OpClass op_class = OpClass::Function) - : name_(std::move(name)){ - if 
(code.empty()) { - code = name_; - } - - code_ = code; - cost_ = cost; - - //add op types - for (const auto& type : op_props) { - props_.set(type); - } - - class_ = op_class; - - // parse the overloads - // example: "ff_f" means two floats in, one float out, "buf_f" means a bool, - // uint, float in, float out - for (const auto& oload : overloads) { - vector inputs; - TFType output = TFType::None; - bool is_output = false; - - for (const auto& c : oload) { - TFType parsed_type = TFType::None; - switch (c) { - case 'f': - parsed_type = TFType::Float; - break; - case 'u': - parsed_type = TFType::Uint; - break; - case 'i': - parsed_type = TFType::Int; - break; - case 'b': - parsed_type = TFType::Bool; - break; - case '_': - is_output = true; - break; - default: - throw std::runtime_error("Invalid character in overload string"); - break; - } - - if (is_output) { - output = parsed_type; - } else { - inputs.push_back(parsed_type); - } - } - - overloads_.emplace_back(inputs, output); - } - } - - bool HasAllTypes(OpProp type) const { - return props_.has(type); - } - - template - bool HasAllTypes(OpProp type, Args... args) const { - return HasAllTypes(type) && HasAllTypes(args...); - } - - bool HasAnyType(OpProp type) const { - return HasAllTypes(type); - } - - template - bool HasAnyType(OpProp type, Args... 
args) const { - return HasAllTypes(type) || HasAnyType(args...); - } - - float GetCost() const { return cost_; } - - string GetName() const { return name_; } - - vector, TFType>> GetOverloads() const { - return overloads_; - } - - size_t GetInputCount() const { - return overloads_[0].first.size(); - } - - bool IsOverloadValid(const pair, TFType>& overload, const vector& input_types) const { - if (overload.first.size() != input_types.size()) { - return false; - } - - for (size_t i = 0; i < input_types.size(); i++) { - if (overload.first[i] != input_types[i].type) { - return false; - } - } - - if (input_types.size() == 0) { - return true; - } - - //check if all format sizes are the same - size_t size = input_types[0].size; - for (size_t i = 1; i < input_types.size(); i++) { - if (input_types[i].size != size) { - return false; - } - } - - return true; - } - - bool IsInputValid(const vector& input_types) const { - for (const auto& overload : overloads_) { - if (IsOverloadValid(overload, input_types)) { - return true; - } - } - return false; - } - - bool IsOutputValid(const TFDataFormat& output_type) const { - for (const auto& overload : overloads_) { - if (overload.second == output_type.type) { - return true; - } - } - return false; - } - - TFDataFormat GetOutputType( - const vector& input_types) const { - for (const auto& overload : overloads_) { - if (IsOverloadValid(overload, input_types)) { - size_t cur_size = default_size; - if (input_types.size() > 0) { - cur_size = input_types[0].size; - } - return {overload.second, cur_size}; - } - } - throw std::runtime_error("Invalid input types for operation"); - } -}; - -const Operation* FindOperation(const string& name); - -string DataTypeToString(TFType type); - -string RemoveSpaces(string str); - -void RegisterNewOperation(const Operation* op); - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Steps/Algorithms.cpp b/TensorFrost/Compiler/Steps/Algorithms.cpp deleted file mode 100644 index d8bf8034..00000000 
--- a/TensorFrost/Compiler/Steps/Algorithms.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "Compiler/KernelGen.h" - -namespace TensorFrost { -bool IR::InsertAlgorithmicPrimitives(bool skip_differentiable) { - // get all nodes for each type - vector nodes = GetNodesOfType(OpProp::Algorithm); - - unordered_set nodes_to_remove; - - // replace all nodes with the algorithmic primitive - for (auto node : nodes) { - if(HasDerivativeImplemented(node->name) && skip_differentiable) { - continue; - } - //compute the sum after the node - ExecuteExpressionAfter(node, [&]() { - //get the input tensor - map inputs = node->args.GetTensors(ArgType::Input); - - //get sum axis - vector axes; - for (int i = 0; i < node->data.size(); i++) { - axes.push_back((int)node->data[i]); - } - - Tensors results; - ImplementationFunction func = GetImplementationForOperation(node->name); - -#ifndef NDEBUG - current_function = node->name; -#endif - - func(results, inputs, node->GetTensor(), axes); - -#ifndef NDEBUG - current_function = "None"; -#endif - - const Tensor* result = results[0]; - - //replace the node with the sum - node->ReplaceThisWithGivenNode(result->node_); - - ShapeCompareResult shape_result = CompareShape(node, result->node_); - if (!shape_result.exactly_compatible) { - throw std::runtime_error("Algorithmic primitive " + node->name + " at " + node->debug_name + " has incompatible shapes"); - } - }); - - //mark the node for removal - nodes_to_remove.insert(node); - } - - // remove all nodes that are not used - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); - - return nodes_to_remove.empty(); -} - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Steps/Autodiff.cpp b/TensorFrost/Compiler/Steps/Autodiff.cpp deleted file mode 100644 index 773839ed..00000000 --- a/TensorFrost/Compiler/Steps/Autodiff.cpp +++ /dev/null @@ -1,154 +0,0 @@ -#include "Compiler/KernelGen.h" - -namespace TensorFrost { - -void ComputeNodeGradients(Node* value, const 
Tensor* grad, NodeGrads& grads) -{ - try { - string op_name = value->name; - //add input arguments - if(value->flags.has(NodeProp::PassGrad)) { - op_name = "passthrough_grad"; - } - if(value->flags.has(NodeProp::DetachGrad)) { - op_name = "detached_grad"; - } - - VJPGradientFunction gradient_func = GetVJPForOperation(op_name); - - Tensor out = *value->tensor_; - gradient_func(value->args, out, *grad, grads); - } catch (const std::exception& e) { - throw std::runtime_error("Error in gradient computation for " + value->debug_name + "(" + to_string(value->debug_index) + "): " + e.what()); - } -} - -bool IR::ComputeAutodiff() -{ - vector gradients = GetNodesOfType(OpProp::Gradient); - - if(gradients.empty()) { - return true; - } - - set loss_nodes; - map, Node*> loss_wrt_grad; - unordered_map min_range; //index of earliest node required for the gradient, end of backpropagation - - for (auto gradient : gradients) { - Node* loss = gradient->args.Get(ArgType::Input, 0); - Node* wrt = gradient->args.Get(ArgType::Input, 1); - Node* last_loss_version = loss->GetLastVersion(gradient); - - loss_nodes.insert(last_loss_version); - if(!min_range.contains(last_loss_version)) { - min_range[last_loss_version] = wrt->index_; - } else { - min_range[last_loss_version] = std::min(min_range[last_loss_version], wrt->index_); - } - loss_wrt_grad[{last_loss_version, wrt}] = gradient; - } - - map grad_to_computed_grad; - for (auto loss : loss_nodes) { - set visited; - map node_to_grad; - - unordered_set loss_deps = GetDependencies({loss}); - - //get all differentiable nodes that can change the loss - vector queue; - for (auto dep : loss_deps) { - bool in_range = (dep->index_ <= loss->index_ && dep->index_ >= min_range[loss]); - bool dep_is_accessible = dep->HasCommonParents(loss); //is it in scope of the loss - if(in_range && !dep->op->HasAllTypes(OpProp::Nondiff) && - dep_is_accessible && (dep->format.type == TFType::Float || dep->op->HasAllTypes(OpProp::Modifier))) { - 
queue.push_back(dep); - } - } - - //sort the nodes by index in descending order (backpropagation) - ranges::sort(queue.begin(), queue.end(), [](Node* a, Node* b) { - return a->index_ > b->index_; - }); - - Node* loss_value = loss; - if(loss->op->HasAllTypes(OpProp::Modifier)) { - loss_value = loss->args.Get(ArgType::Memory); - } - - ExecuteExpressionAfter(loss, [&]() { - node_to_grad[loss_value] = &Tensor::Constant(1.0f); - for(auto node : queue) { - if(!node_to_grad.contains(node) && !node->op->HasAllTypes(OpProp::Modifier)) { - continue; - } - - node_to_grad[node] = &ReduceGradientToShape(*node_to_grad[node], *node->GetTensor()); - - NodeGrads grads = NodeGrads(node, node_to_grad); - - #ifndef NDEBUG - current_function = node->name + "_grad"; - #endif - - ComputeNodeGradients(node, node_to_grad[node], grads); - - #ifndef NDEBUG - current_function = "None"; - #endif - - //store the computed gradients - for (auto& [id, input]: node->args.Inputs()) { - if(!grads.Contains(id)) { - continue; - } - - const Tensor& new_grad = *grads.GetGrad(id); - node_to_grad[input] = &new_grad; - - //TODO: maybe add a function to get temp names - if(input->debug_name != "") { - new_grad.SetDebugName("d" + loss_value->debug_name + "_d" + input->debug_name); - } else if(input->var_name != "") { - new_grad.SetDebugName("d" + loss_value->debug_name + "_d" + input->var_name); - } - } - } - }); - - for (auto wrt_grad : loss_wrt_grad) { - if (wrt_grad.first.first != loss) { - continue; - } - - Node* grad = wrt_grad.second; - if(!node_to_grad.contains(wrt_grad.first.second)) { - throw std::runtime_error("Gradient not computed for " + wrt_grad.first.second->var_name); - } - Node* computed_grad = node_to_grad[wrt_grad.first.second]->node_; - grad_to_computed_grad[grad] = computed_grad; - } - } - - unordered_set nodes_to_remove; - //replace all gradients with computed gradients - for (auto gradient : gradients) { - Node* computed_grad = grad_to_computed_grad[gradient]; - //replace the node with 
the sum - gradient->ReplaceThisWithGivenNode(computed_grad); - - //mark the node for removal - nodes_to_remove.insert(gradient); - } - - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); - - return false; -} - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Steps/GraphOps.cpp b/TensorFrost/Compiler/Steps/GraphOps.cpp deleted file mode 100644 index 85aef0ff..00000000 --- a/TensorFrost/Compiler/Steps/GraphOps.cpp +++ /dev/null @@ -1,1343 +0,0 @@ -#include "Compiler/KernelGen.h" -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { - -void KernelScope::CreateKernel() { - //create kernel node - Tensor& tensor = Tensor::Kernel(scope_shape.GetTensors()); - Node* kernel_node = tensor.node_; - Node* old_child = kernel_node->child; - kernel_node->child = begin; - kernel_node->next = end->next; - begin->parent = kernel_node; - begin->prev = nullptr; - end->next->prev = kernel_node; - end->next = old_child; - old_child->prev = end; -} - -void IR::SeparateOperationsIntoKernels() { - - unordered_set kernel_scopes; - - ExecuteExpressionFirstChild(root, [&]() { - kernel_scopes = KernelScope::ComputeScopes(root).first; - }); - - // create kernel nodes for all kernel scopes - for (auto scope : kernel_scopes) { - // create kernel before the scope - ExecuteExpressionBefore(scope->begin, [&]() { - scope->CreateKernel(); - }); -#ifdef _RELWITHDEBINFO - if (scope->boundary_nodes.size() > max_kernel_memory_dependencies) { - cout << current_pass << ": Warning: Kernel created with " << scope->boundary_nodes.size() << " boundary nodes" << endl; - } -#endif - } - - UpdateGraph(); -} - - -unordered_set IR::ComputeKernelDependencies(Node* kernel) { - unordered_set kernel_deps; - - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - //go over all inputs - unordered_set node_deps; - for (auto& [id, from] : node->args.Inputs()) { - if(id.first == ArgType::Shape) { - continue; - } - bool is_outside = 
!from->HasParent(kernel); - if (is_outside) { - //check if its not scalar - if(from->args.Count(ArgType::Shape) == 0) { - continue; - } - node_deps.insert(from); - } else { - node_deps.insert(from->memory_deps.begin(), from->memory_deps.end()); - } - } - //go over all outputs and add this node as a dependency if those outputs are outside the kernel - //this means the kernel must also write to an output - for (auto [edge, to] : node->args.Outputs()) { - if(edge.first.first == ArgType::Shape) { - continue; - } - if (!to->HasParent(kernel)) { - node_deps.insert(node.get()); - } - } - node->memory_deps = node_deps; - kernel_deps.insert(node_deps.begin(), node_deps.end()); - } - - kernel->memory_deps = kernel_deps; - return kernel_deps; -} - - -// check if all child nodes in a kernel have compatible shape to the kernel -void IR::CheckKernelShapes() { - // get kernels - vector kernels = GetNodesOfType("kernel"); - - // go over all outputs of each kernel and create memory nodes to store the - // output - for (auto kernel : kernels) { - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - // check if the node has a shape argument - ShapeCompareResult result = CompareShape(kernel, node.get(), true); - } - -#ifdef _RELWITHDEBINFO - auto deps = ComputeKernelDependencies(kernel); - if(deps.size() > max_kernel_memory_dependencies) { - cout << current_pass << ": Warning: Kernel " << kernel->debug_index << " has " << deps.size() << " dependencies" << endl; - } -#endif - } - - - UpdateGraph(); -} - -void IR::UpdateKernelShapes() { - // get kernels - vector kernels = GetNodesOfType("kernel"); - - // go over all outputs of each kernel and create memory nodes to store the - // output - for (auto kernel : kernels) { - NodeArguments kernel_shape = kernel->args.GetArguments(ArgType::Shape); - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - //set the shape of all nodes in the kernel to the kernel shape - node->args.RemoveArguments(ArgType::Shape); - 
node->args.AddArguments(kernel_shape); - } - } - - UpdateGraph(); -} - -bool IR::LimitKernelMemoryDependencies() { - UpdateGraph(); - vector kernels = GetNodesOfType("kernel"); - - int created_kernels = 0; - - for (auto kernel : kernels) { - unordered_set kernel_deps = ComputeKernelDependencies(kernel); - } - - for (auto kernel : kernels) { - if(kernel->memory_deps.size() <= max_allowed_memory_dependencies) continue; - - //throw std::runtime_error("Kernel " + to_string(kernel->index_) + " has too many memory dependencies (" + to_string(kernel_deps.size()) + " > " + to_string(max_kernel_memory_dependencies) + ")"); - } - - UpdateGraph(); - - return created_kernels == 0; -} - -void IR::UnrollOperations() { - UpdateGraph(); - vector kernels = GetNodesOfType("kernel"); - - Tensor* const_one; - Tensor* const_zero; - - ExecuteExpressionFirstChild(root, [&]() { - const_one = &Tensor::Constant(1, TFTypeInt32); - const_zero = &Tensor::Constant(0, TFTypeInt32); - }); - - vector> nodes_to_store; - for (auto kernel : kernels) { - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if(node->name == "dim_id" || node->name == "const") { - continue; - } - // check if the node has a shape argument - ShapeCompareResult result = CompareShape(kernel, node.get(), true); - if (!result.exactly_compatible) { - if(!result.unroll_compatible) continue; - ShapeInfo unroll_shape = node->tensor_->GetShapeInfo(); - vector dims = unroll_shape.GetShape(); - vector> unrolled_dims; - for (int broadcast_dim: result.broadcast_dims) { - unrolled_dims.push_back({dims[broadcast_dim], broadcast_dim}); - } - int unrolled_size = 1; - for (int i = 0; i < unrolled_dims.size(); i++) { - unrolled_size *= unrolled_dims[i].first; - } - cout << "Unrolling node " << node->name << " in kernel " << kernel->name << endl; - - ExecuteExpressionBefore(*node, [&] { - //create local memory with size = to size difference - Tensor* local_memory = nullptr; - if(node->format.type != TFType::None) { - 
local_memory = &Tensor::LocalMemory(unrolled_size, node->format); - } - - //create loops for each broadcasted dimension - Node* last_loop = nullptr; - map loop_indices; - for (int i = 0; i < unrolled_dims.size(); i++) { - Tensor* const_loop_size = &Tensor::Constant(unrolled_dims[i].first, TFTypeInt32); - if(last_loop == nullptr) { - last_loop = Tensor::Loop(*const_zero, *const_loop_size, *const_one).node_; - } else { - ExecuteExpressionFirstChild(last_loop, [&] { - last_loop = Tensor::Loop(*const_zero, *const_loop_size, *const_one).node_; - }); - } - loop_indices[unrolled_dims[i].second] = last_loop; - } - - //go over inputs of the node and replace them with the loop indices if they are dim_id's of the right dimensions - //if the input is a local memory, then load the value from the memory - for (auto& [id, from] : node->args.InputsCopy()) { - if(from->name == "dim_id") { - int dim = from->data[0]; - if(result.broadcast_dims.contains(dim)) { - //replace the input with the loop index - Node* loop_index = loop_indices[dim]; - node->args.UpdateArgument(id, loop_index); - } - } else if(from->op->HasAllTypes(OpProp::LocalMemory)){ - ExecuteExpressionLastChild(last_loop, [&] { - //load the value from the memory - //TODO compute proper index - Tensor* load_value = &Tensor::Load(*from->tensor_, {const_zero}); - node->args.UpdateArgument(id, load_value->node_); - }); - } - } - - //move this node inside the loop - MoveNodeTo(last_loop->GetLastChild(), node.get()); - - if(local_memory != nullptr) { - //replace all outputs of the node with the local memory - node->ReplaceThisWithGivenNode(local_memory->node_); - nodes_to_store.push_back({local_memory->node_, node.get()}); - } - - //TODO special handling of global mem load operations - }); - - } - } - } - - UpdateGraph(); - - for (auto [local_memory, node] : nodes_to_store) { - ExecuteExpressionAfter(node, [&] { - Tensor* store_value = &Tensor::Store(*local_memory->tensor_, *node->tensor_, {const_zero}); - }); - } -} - -void 
IR::CheckIR(string name, bool check_clustering, bool check_kernels) { -#ifdef NDEBUG - return; -#endif - UpdateGraph(); - - map invalid_nodes; - //check if the IR is clusterized correctly - for (auto node = begin(); !node.end(); node.next()) { - bool identity = node->args.Count(ArgType::Index) == 0; - - Node* prev = node->prev; - - if (prev == nullptr) continue; - - - // go over all inputs - for (auto& [id, input] : node->args.Inputs()) { - Node* to = node.get(); - - // check if inputs are before the node - if (input->index_ >= to->index_ && input->name != "const") { - if (id.first != ArgType::Shape) { - invalid_nodes[to] = "Argument " + TypeToString(id.first) + ":" + - to_string(id.second) + " is after the node"; - } - } - } - } - - string listing = PrintListing(invalid_nodes); - - if (!invalid_nodes.empty()) { - listing += "Step [" + name + "] failed. "; - throw std::runtime_error(listing); - } else { - cout << "Step [" << name << "] completed successfully: \n" << endl; - cout << listing << endl; - } -} - -void IR::ReorderOperations() { - // get kernel data - vector kernels = GetNodesOfType("kernel"); - - for (auto* kernel: kernels) { - unordered_set nodes_to_move; - // go over all nodes in the kernel and check if their inputs can be copied - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - // go over all inputs - for (auto& [id, from] : node->args.Inputs()) { - bool outside_kernel = !from->HasParent(kernel); - if (outside_kernel && !node->args.CannotMoveArgument(id)) { - // if this node is a set and its input is outside of the cluser -> - // move it inside - if (node->op->HasAllTypes(OpProp::Set)) { - nodes_to_move.insert(from); - } - } - } - } - - //TODO (Moroz): do a check on order of the moved nodes - seems to be breaking sometimes - - // move all the nodes that are outside the kernel inside - Node* kernel_begin = kernel->child; - for (auto* node : nodes_to_move) { - MoveNodeTo(kernel_begin, node); - } - } - - UpdateGraph(); -} - - -/// 
-/// Copy nodes together with their arguments (as far as possible) -/// -/// nodes to copy -/// if given, the indices to use -/// map between the original nodes and the copied nodes -map IR::CopyComputation( - const unordered_set& targets, const unordered_map& indices) { - - // do a depth first search to copy all the nodes required for the targets - // (only if in the same kernel) - set nodes_to_copy; - bool valid = true; - std::function dfs = [&](Node* node) { - if (nodes_to_copy.contains(node)) return; - nodes_to_copy.insert(node); - for (auto& [arg, from] : node->args.Inputs()) { - if (node->args.CannotCopyArgument(arg)) { - continue; - } - dfs(from); - } - }; - - for (Node* target : targets) { - dfs(target); - } - - - return CopyNodes(nodes_to_copy, {}, indices, targets, true); -} - -map IR::CopyNodesWithIndex(unordered_set nodes_to_copy, - unordered_map indices, - Node *cursor) { - // copy all the nodes at the beginning of the kernel - map copied_node_map; - if(cursor == nullptr) { - copied_node_map = CopyComputation(nodes_to_copy, indices); - } else { - ExecuteExpressionBefore(cursor, [&]() { - copied_node_map = CopyComputation(nodes_to_copy, indices); - }); - } - return copied_node_map; -} - -void IR::CopyArguments(ArgEdges args_to_copy, Node* cursor) -{ - unordered_set nodes_to_copy; - for (auto& [arg, out] : args_to_copy) { - nodes_to_copy.insert(arg.second); - } - - // copy all the nodes at the beginning of the kernel - map copied_node_map; - unordered_map indices; - copied_node_map = CopyNodesWithIndex(nodes_to_copy, indices, cursor); - - ReplaceArgs(args_to_copy, copied_node_map); -} - -void IR::MoveShapeOutsideKernels() { - UpdateGraph(); - // find all nodes that are used as shapes and are inside kernels - map nodes_to_copy; - for (auto node = begin(); !node.end(); node.next()) { - Node* kernel = node->GetParent("kernel"); - if (kernel == *node) { //if returns itself, then no kernel parent found - continue; - } - - // go over all outputs arguments - 
for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (id.first != ArgType::Shape) { - continue; - } - // add the node to the set - nodes_to_copy[node.get()] = kernel; - } - } - - for (auto [ node, kernel ] : nodes_to_copy) { - //get all output arguments that are shapes - ArgEdges args_to_copy; - int earliest_output_index = INT_MAX; - Node* earliest_output = nullptr; - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (id.first == ArgType::Shape) { - args_to_copy.insert(ArgEdge(Arg(id, node), to)); - - //get the earliest output - if (to->index_ < earliest_output_index) { //wat - earliest_output_index = to->index_; - earliest_output = to; - } - } - } - - Node* common_parent = earliest_output->GetNodeWithCommonParent(kernel); - - // copy shape computation and put it before the earliest output (outside of the kernel if its inside) - CopyArguments(args_to_copy, common_parent); - ApplyChanges(false); - } -} - - -/// -/// Get all inputs of this program in the IR -/// -void IR::GetInputList() { - int input_memory_index = 0; - //MUST BE IN ORDER - for (auto node = begin(); !node.end(); node.next()) { - if (node->flags.has(NodeProp::InputMemory)) { - shape_memory_map[*node] = {}; - // add shapes to the memory inputs - for (int i = 0; i < node->args.Count(ArgType::Shape); i++) { - Node* shape_node = node->args.Get(ArgType::Shape, i); - shape_memory_map[*node][i] = shape_node; - } - - // set input memory index - int input_index = input_memory_index++; - // add shapes to the memory inputs - input_memory_map[input_index] = *node; - node->flags.set(NodeProp::InputMemory, (int64_t)input_index); - //if any of the inputs are "input_shape" then we need to add the input index to them - for (auto& [arg, from] : node->args.Inputs()) { - if (arg.first == ArgType::Shape && from->name == "input_shape") { - if(!from->flags.has(NodeProp::InputShapeMemory)) { //ONLY FIRST TIME - from->flags.set(NodeProp::InputShapeMemory, 
(int64_t)input_index); - } - } - } - } - } -} - -/// -/// Get all outputs of this program in the IR -/// -void IR::GetOutputList() { - for (auto node = begin(); !node.end(); node.next()) { - if (node->flags.has(NodeProp::OutputMemory)) { - if (!node->op->HasAllTypes(OpProp::Memory)) { - throw std::runtime_error( - "Compilation error: output is not a memory node"); // all outputs - // should be - // memory nodes - // at this point - } - output_memory_map[(int)node->flags.get(NodeProp::OutputMemory)] = *node; - } - if (node->op->HasAllTypes(OpProp::Modifier, OpProp::MemoryOp)) { - if (!node->HasParent("kernel")) { - writebacks++; - } - } else if (node->op->HasAllTypes(OpProp::Load, OpProp::MemoryOp)) { - if (!node->HasParent("kernel")) { - readbacks++; - } - } - } -} - -/// -/// Compute statistics about the IR -/// -void IR::ComputeStatistics() { - for (auto node = begin(); !node.end(); node.next()) { - if (node->name == "memory") { - bool is_input = node->flags.has(NodeProp::InputMemory); - bool is_output = node->flags.has(NodeProp::OutputMemory); - if (is_input) { - input_memory_count++; - } - if (!is_input && !is_output) { - temp_memory_count++; - } - } - } - - //Check if output memory map has all the outputs - if (output_memory_map.size() < output_memory_count) { - throw std::runtime_error("Output memory map does not have all the outputs, some got lost"); - } else if (output_memory_map.size() > output_memory_count) { - throw std::runtime_error("Output memory map has more outputs than expected"); - } -} - - - -unordered_set IR::GetDependencies(unordered_set nodes) { - unordered_set dependencies; - std::function dfs = [&](Node* node) - { - if (dependencies.contains(node)) { - return; - } - - dependencies.insert(node); - - //all inputs of this node are used - for (auto& [arg, from] : node->args.Inputs()) { - dfs(from); - } - - //if the node is a memory node or used as memory, then all outputs are used - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, 
from] = edge; - if (to->args.IsChangingInput(id)) { - dfs(to); - } - } - }; - - for(auto node : nodes) { - dfs(node); - } - - return dependencies; -} - -void IR::ComputeNodeCost() -{ - for (auto node = begin(); !node.end(); node.next()) { - bool is_memory = node->op->HasAllTypes(OpProp::Memory); - unordered_map input_costs; - for (auto& [id, from] : node->args.Inputs()) { - if (id.first != ArgType::Memory && - (id.first != ArgType::Shape && !is_memory)) { - input_costs[from] = from->cost_; - } - } - float input_cost = node->op->GetCost(); - for (auto& input : input_costs) { - input_cost += abs(input.second); - } - node->cost_ = input_cost; - - //go over outputs and check if it has any load operations - bool is_used_as_memory = false; - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if(id.first == ArgType::Memory) { - is_used_as_memory = true; - break; - } - } - - if(is_used_as_memory && input_cost > 128.0) { - node->flags.set(NodeProp::NoCopyFusion); - } - } -} - -map IR::GetKernelOutputs(Node *kernel) -{ - map node_output; - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - bool is_output = node->flags.has(NodeProp::OutputMemory); - ArgEdges outputs = ArgEdges(); - - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (to == nullptr) continue; - // if is a shape or memory argument, then skip (shape is loaded on CPU) - if (id.first == ArgType::Shape) continue; - if (!to->HasParent(kernel)) { - outputs.emplace(Arg(id, *node), to); - is_output = true; - } - } - - if (is_output) { - node_output[*node] = outputs; - } - } - - return node_output; -} - -string IR::PrintListing(map node_debug) const { - return GetOperationListing(*this, false, node_debug) + "\n\n"; -} - -string IR::GetNodeListing(Node* node) const { - return GetNodeString(node, true); -} - -/// -/// Copy given nodes -/// -/// target nodes to copy -/// if given, the arguments to replace -/// if given, the indices to use -/// if 
given, the target nodes -/// if true, all nodes and their arguments must be copied -/// mappings between the original nodes and the copied nodes -map IR::CopyNodes( - set nodes_to_copy, - unordered_map argument_replacements, - unordered_map indices, - unordered_set targets, bool must_copy_all) { - - // if we have indices, we are basically rerunning the computation with a - // different set of indices (of possible different shape) - bool can_change_shape = !indices.empty(); - NodeArguments shape_args = NodeArguments(); - if (can_change_shape) { - // get first index - int first_index = indices.begin()->first; - Node* first_index_node = indices.at(first_index); - shape_args = first_index_node->args.GetArguments(ArgType::Shape); - } - - if (nodes_to_copy.empty()) { - return {}; - } - - //TODO: figure out a better heuristic for this, or add as a warning - //(need to implement a logging system) - // if (nodes_to_copy.size() > 1024) { - // throw std::runtime_error( - // "Copy Nodes: Copying too many nodes, something is probably " - // "wrong. 
Number of nodes to copy: " + - // to_string(nodes_to_copy.size())); - // } - - // copy the nodes - map copied_node_map; - for (auto node = begin(); !node.end(); node.next()) { - if (!nodes_to_copy.contains(node.get())) { - continue; - } - - Node* new_node; - - // if we have the index, use it instead - bool is_dim = node->name == "dim_id"; - bool no_index = true; - if (is_dim) { - int dim = node->data[0]; - if (indices.contains(dim)) { - Node* new_index = indices.at(dim); - if(new_index == nullptr) { - throw std::runtime_error("Copy Nodes: New index is null for node " + node->name); - } - new_node = new_index; - no_index = false; - } - } - - if (no_index) { - // create new arguments - NodeArguments new_args; - for (auto& [arg, from]: node->args.Inputs()) { - auto& [type, index] = arg; - if (can_change_shape && type == ArgType::Shape) { - continue; - } - - // if shape or memory argument, then no need to use copied node - if (node->args.CannotCopyArgument(arg) && !targets.contains(from) && !argument_replacements.contains(from)) { - if(from == nullptr) { - throw std::runtime_error("Copy Nodes: From is null for node " + node->name); - } - new_args[arg] = from; - continue; - } - - Node* new_from = from; - - if (argument_replacements.contains(from)) { - new_from = argument_replacements[from]; - } else if (nodes_to_copy.contains(from)) { - if(!copied_node_map.contains(from)) { - throw std::runtime_error("Copy Nodes: No replacement for node " + from->name); - } - new_from = copied_node_map[from]; - } else if (must_copy_all) { - throw std::runtime_error("Copy Nodes: No replacement for node " + from->name + " but we must copy all nodes"); - } - - if(new_from == nullptr) { - throw std::runtime_error("Copy Nodes: New from is null for node " + from->name); - } - - // create new argument - new_args[arg] = new_from; - } - - if (can_change_shape) { - new_args.insert(shape_args.begin(), shape_args.end()); - } - - // create new node - Tensor* tensor = 
Tensor::GetCopy(*node->GetTensor(), new_args); - new_node = tensor->node_; - } - - if(new_node == nullptr) { - throw std::runtime_error("Copy Nodes: New node is null for node " + node->name); - } - - copied_node_map[node.get()] = new_node; - } - - return copied_node_map; -} - - -void IR::AddNodeLoadOperations(Node* node, Node* kernel, Tensors indices) { - for (auto& [arg, input_node] : node->args.InputsCopy()) { - if (arg.first == ArgType::Memory || arg.first == ArgType::Shape) - continue; - - bool is_in_a_kernel = input_node->HasParent("kernel"); - bool is_outside = !input_node->HasParent(kernel); - bool is_memory = input_node->op->HasAllTypes(OpProp::Memory); - - if (is_memory || (is_in_a_kernel && is_outside)) { - // load the memory node before this node - ExecuteExpressionBefore(node, [&]() { - Tensor& loaded = Tensor::Load(*input_node->GetTensor(), indices, IndexingMode::Unsafe); - node->args.UpdateArgument(arg, loaded.node_); - }); - } - } -} - -void IR::AddKernelGlobalLoadOperations() { - // get kernels - vector kernels = GetNodesOfType("kernel"); - for (auto kernel : kernels) { - - // replace all inputs pointing to memory nodes with the memory node - unordered_set nodes_to_load; - unordered_map load_arguments; - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - for (auto& [arg, input_node] : node->args.Inputs()) { - if (arg.first == ArgType::Memory || arg.first == ArgType::Shape) - continue; - - bool is_in_a_kernel = input_node->HasParent("kernel"); - bool is_outside = !input_node->HasParent(kernel); - bool is_memory = input_node->op->HasAllTypes(OpProp::Memory); - - if (is_memory || (is_in_a_kernel && is_outside)) { - nodes_to_load.insert(input_node); - load_arguments[input_node].insert(ArgEdge(Arg(arg, input_node), node.get())); - } - } - } - - for (auto node : nodes_to_load) { - // load the memory node at the beginning of the kernel - ExecuteExpressionFirstChild(kernel, [&]() { - Tensor& loaded = Tensor::Load(*node->GetTensor(), {}, 
IndexingMode::Unsafe); - for (auto [in, out] : load_arguments[node]) { - auto& [arg, from] = in; - out->args.UpdateArgument(arg, loaded.node_); - } - }); - } - } - - UpdateGraph(); -} - - -void IR::AddMemoryOpIndices() { - // get kernels - vector kernels = GetNodesOfType("kernel"); - for (auto kernel : kernels) { - // get kernel shape arguments - NodeArguments shape_args = kernel->args.GetArguments(ArgType::Shape); - - Tensors indices = Tensors(); - // add dimension index nodes - ExecuteExpressionFirstChild(kernel, [&]() { - for (int i = 0; i < shape_args.size(); i++) { - indices.push_back(&Tensor::Index(shape_args, i)); - } - }); - int kernel_dim = (int)shape_args.size(); - - // replace all inputs pointing to memory nodes with the memory node - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (!node->op->HasAllTypes(OpProp::MemoryOp)) { - continue; - } - - Node* input_node = node->args.Get(ArgType::Memory); - map shape = input_node->args.GetTensors(ArgType::Shape); - - int memory_dim = (int)shape.size(); - ExecuteExpressionBefore(node.get(), [&]() { - for (int i = 0; i < memory_dim; i++) { - if (node->args.Has(ArgType::Index,i)) { - continue; - } - if (memory_dim > kernel_dim) { - throw std::runtime_error( - "Memory dimension is greater than kernel dimension, we can't " - "implicitly broadcast"); - } - const Tensor* index = nullptr; - // if the shape is 1, then we broadcast and set the index to 0 - if (isConstantAndEqualTo(shape[i], 1.0)) { - index = &Tensor::Constant(0); - } else { - index = indices[i]; - } - node->args.AddArgument(ArgType::Index, i, index->node_); - } - }); - } - } - - UpdateGraph(); -} - -void IR::AddKernelGlobalStoreOperations() { - UpdateGraph(); - // get kernels - vector kernels = GetNodesOfType("kernel"); - - // go over all outputs of each kernel and create memory nodes to store the - // output - for (auto kernel: kernels) { - map node_output = GetKernelOutputs(kernel); - - for (auto [output, args] : node_output) { - 
// if the output is already a memory node, then skip - if (output->op->HasAllTypes(OpProp::Memory)) { - continue; - } - - Node* mem; - // add memory node before this kernel - ExecuteExpressionBefore(kernel, [&]() { - mem = Tensor::Memory(kernel->args.GetArguments(ArgType::Shape), output->format).node_; - mem->debug_name = output->debug_name; - - if (output->flags.has(NodeProp::OutputMemory)) { - mem->flags.copy_all_given(output->flags, { NodeProp::OutputMemory }); - output->flags.remove(NodeProp::OutputMemory); - } - }); - - // go over all outputs of this node and replace their input with the - // memory node - for (auto& [arg, to] : args) { - auto id = arg.first; - if (id.first != ArgType::Shape && - id.first != ArgType::Memory) { - // if not a memory or shape argument, then the memory needs to be - // loaded before the node - ExecuteExpressionBefore(to, [&]() { - Tensor& loaded = Tensor::Load(*mem->GetTensor(), {}, IndexingMode::Unsafe); - // the node must now use the loaded value - to->args.UpdateArgument(id, loaded.node_); - }); - } else { - // otherwise the memory can be used directly - to->args.UpdateArgument(id, mem); - } - } - - //get last modification of the memory - Node* last_mod = output->GetFinalVersion(); - //get the parent of the last modification on the same level as the memory - Node* last_mod_parent = last_mod->GetNodeWithCommonParent(output); - - // add store node after the last modification on the same level as the memory - ExecuteExpressionAfter(last_mod_parent, [&]() { - // add store node after this node - Tensor* store = &Tensor::Store(*mem->GetTensor(), *output->GetTensor(), {}, IndexingMode::Unsafe); - }); - } - } - - // replace all inputs pointing to memory nodes with the memory node - for (auto node = begin(); !node.end(); node.next()) { - bool is_memory = node->op->HasAllTypes(OpProp::Memory); - - for (auto& [id, from] : node->args.InputsCopy()) { - if (id.first == ArgType::Memory || - (id.first == ArgType::Shape && !is_memory)) - 
continue; - - if (from->op->HasAllTypes(OpProp::Memory)) { - // load the memory node before this node - ExecuteExpressionBefore(node.get(), [&]() { - Tensor& loaded = Tensor::Load(*from->GetTensor(), {}, IndexingMode::Unsafe); - node->args.UpdateArgument(id, loaded.node_); - }); - } - } - } - - UpdateGraph(); -} - -void IR::AddMemoryDeallocation() -{ - vector memory_nodes = GetNodesOfType("memory"); - - // go over all outputs of each memory and and put a deallocation node after the last time it is used - for (auto memory : memory_nodes) { - // skip input and output memories, they are deallocated manually - if (memory->flags.has(NodeProp::InputMemory)) { - continue; - } - - Node* last_output = nullptr; - int last_output_index = -1; - - bool is_an_output = false; - - //do a dfs to find the last output - std::function dfs = [&](Node* node) { - if (node->flags.has(NodeProp::OutputMemory)) { - is_an_output = true; - return; - } - - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (to->op->HasAllTypes(OpProp::MemoryReuse)) { - dfs(to); - } else { - if (last_output_index < to->index_) { - last_output_index = to->index_; - last_output = to; - } - } - } - }; - - dfs(memory); - - if (is_an_output) { - continue; - } - - // need to add deallication in the same scope as the allocation - Node* deallocation_point = last_output->GetNodeWithCommonParent(memory); - - // add deallocation node after the last time the memory is used - ExecuteExpressionAfter(deallocation_point, [&]() { - Tensor* deallocate = &Tensor::Deallocate(*memory->GetTensor()); - }); - } - - UpdateGraph(); -} - -// compute the flat index (in C-order) -Tensor* ComputeFlatIndex(NodeArguments memory_shape, Tensors indices, map idx, int memory_dim, IndexingMode mode = IndexingMode::Clamp) -{ - if (memory_dim == 0) - { - return &Tensor::Constant(0); - } - - int kernel_dim = (int)indices.size(); - - function get_shape = [&](int dim) { - dim = memory_dim - dim - 1; - return 
memory_shape[ArgID(ArgType::Shape, dim)]->GetTensor(); - }; - - // function to get index for given dimension, if not found then return - // default dim index - function get_index = [&](int dim) { - int idxdim = memory_dim - dim - 1; - Tensor* out; - const Tensor& shape = *get_shape(dim); - if (idx.find(idxdim) != idx.end()) { - out = const_cast(idx[idxdim]); - } else { - throw std::runtime_error("Finalize memory indexing: node index not found for dimension " + to_string(idxdim) + " in memory node with dimensions " + to_string(memory_dim)); - } - - //if index is uint then cast it to int - if (out->node_->format.type == Uint) { - out = &Tensor::toint(*out); - } - - switch (mode) - { - case IndexingMode::Clamp: - return &Tensor::clamp( - *out, TensorFrost::Tensor::Constant(0), - shape - TensorFrost::Tensor::Constant(1)); - case IndexingMode::Repeat: - return &(*out - (*out / shape) * shape); - case IndexingMode::Unsafe: - return out; - default: //TODO (Moroz): add other modes - throw std::runtime_error("Finalize memory indexing: invalid tensor indexing mode"); - } - }; - - Tensors shape = Tensors(); - Tensors index = Tensors(); - for (int i = 0; i < memory_dim; i++) { - shape.push_back(get_shape(i)); - index.push_back(get_index(i)); - } - - return &Tensor::FlatIndex(shape, index); -} - -void IR::ReplaceDimNodes(Node* kernel, Tensors indices, int dims) -{ - // replace all dim nodes with the corresponding index node - unordered_set nodes_to_remove; - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (node->name == "dim_id") { // remove the dim node - nodes_to_remove.insert(node.get()); - } - else - { - //go over node inputs and replace dim nodes with index nodes - for (auto& [id, from] : node->args.InputsCopy()) { - if (from->name == "dim_id") { - int dim = from->data[0]; - Node* index_node = nullptr; - if (dim >= dims) { //if dim node of dimension greater than the number of dimensions its 0 - ExecuteExpressionBefore(node.get(), [&]() { - 
index_node = Tensor::Constant(0).node_; - }); - } else { - index_node = indices[dim]->node_; - } - - // replace the dim node with the index node - node->args.UpdateArgument(id, index_node); - } - } - } - } - - // remove all dim nodes - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } -} - -void IR::MultiDimensionalModeIndices(vector& indices, Node* kernel_, int dims, Tensors kernel_shape) -{ - //add dim_id nodes at the beginning of the kernel - ExecuteExpressionFirstChild(kernel_, [&]() { - for (int i = 0; i < dims; i++) { - indices[i] = &Tensor::Index(kernel_shape, i); - } - }); -} - -Tensors ComputeIndicesFromBlockIndex(Tensor* block_index, Node* kernel, - Tensors kernel_shape, int dims) { - //compute in-block index - vector block_size = kernel->group_size; - int block_dim = (int)block_size.size(); - Tensors block_size_tensors = {}; - for (int i = 0; i < block_dim; i++) { - block_size_tensors.push_back(&Tensor::Constant(block_size[i])); - } - vector in_block_indices; - for (int i = 0; i < block_dim; i++) { - in_block_indices.push_back(&block_index->BlockThreadIndex(i)); - } - - //compute out-of-block index - Tensors blocks_shape = {}; - for (int i = 0; i < block_dim; i++) { - const Tensor block_size = *block_size_tensors[i]; - const Tensor shape = *kernel_shape[i]; - Tensor& ceil = (shape + block_size - Tensor::Constant(1)) / block_size; - blocks_shape.push_back(&ceil); - blocks_shape[i]->SetDebugName("blocks_shape_" + to_string(i)); - } - for (int i = block_dim; i < dims; i++) { - blocks_shape.push_back(kernel_shape[i]); - blocks_shape[i]->SetDebugName("blocks_shape_" + to_string(i)); - } - Tensors out_block_indices = Tensor::IndicesFromFlatIndex(block_index, blocks_shape); - - //combine the final indices - Tensors indices = {}; - - for (int i = 0; i < block_dim; i++) { - indices.push_back(&(*out_block_indices[i] * *block_size_tensors[i] + *in_block_indices[i])); - indices[i]->SetDebugName("index_" + to_string(i)); - } - - for (int i = block_dim; i < 
dims; i++) { - indices.push_back(out_block_indices[i]); - indices[i]->SetDebugName("index_" + to_string(i)); - } - - return indices; -} - -Tensor* IR::LinearBlockModeIndices(Tensors& indices, Node* kernel_, int dims, Tensors kernel_shape) -{ - Tensor* block_index = nullptr; - Tensor* if_tensor = nullptr; - ExecuteExpressionFirstChild(kernel_, [&]() { - block_index = &kernel_->GetTensor()->BlockIndex(); - - if(kernel_->group_size.size() == 0) { //if group size is not set, then set it to default - switch (dims) - { - case 1: - kernel_->group_size = {256}; - break; - case 2: - kernel_->group_size = {16, 16}; - break; - case 3: - kernel_->group_size = {8, 8, 8}; - break; - default: - kernel_->group_size = {8, 8, 8}; - } - - //if the dimensions are known, then use the minimum of the group size and the shape to avoid useless computation - int group_dim = (int)kernel_->group_size.size(); - for (int i = 0; i < group_dim; i++) { - int shape = kernel_shape[i]->TryGetConstant(); - if (shape > 0) { - kernel_->group_size[i] = min(kernel_->group_size[i], shape); - } - } - } - - indices = ComputeIndicesFromBlockIndex(block_index, kernel_, kernel_shape, dims); - - //add a check for if inside the dispatch - Tensor* inside_dispatch = &(*indices[0] < *kernel_shape[0]); - for (int i = 1; i < dims; i++) { - inside_dispatch = &(*inside_dispatch && *indices[i] < *kernel_shape[i]); - } - inside_dispatch->SetDebugName("is_inside_dispatch"); - - //put an if condition - if_tensor = &Tensor::If(*inside_dispatch); - if_tensor->SetDebugName("if_inside_dispatch"); - }); - - ReplaceDimNodes(kernel_, indices, dims); - - return if_tensor; -} - -void IR::ComputeAddress(Node* node, Tensors indices) -{ - // get the input memory node - const Tensor* memory = node->args.GetTensor(ArgType::Memory); - - NodeArguments memory_shape = memory->node_->args.GetArguments(ArgType::Shape); - - int memory_dim = (int)memory_shape.size(); - - // get the index nodes - map idx = node->args.GetTensors(ArgType::Index); - 
- if (idx.empty()) - { - node->indexing_mode_ = IndexingMode::Unsafe; //we can guarantee that the index is in bounds - } - - auto flat_index= ComputeFlatIndex(memory_shape, indices, idx, memory_dim, node->indexing_mode_); - - // TODO(Moroz): add wrap mode - - // remove the index node edges - node->args.RemoveArguments(ArgType::Index); - - // add the flat index node edge - node->args.AddArgument(ArgType::Index, 0, flat_index->node_); -} - -void IR::FinalizeMemoryIndexing() { - vector kernels = GetNodesOfType("kernel"); - - vector dispatch_checks; - - for (auto kernel : kernels) { - Node* shape_node = kernel; - if (shape_node == nullptr) continue; - // load kernel shape - map kernel_shape_map = shape_node->args.GetTensors(ArgType::Shape); - Tensors kernel_shape; - for (auto& shape : kernel_shape_map) { - kernel_shape.push_back(shape.second); - } - - if (kernel_shape.empty()) { - // can skip if no kernel shape - no index - continue; - } - - // compute the index for each dimension - int dims = (int)kernel_shape.size(); - Tensors indices = Tensors(dims); - dispatch_checks.push_back( - LinearBlockModeIndices(indices, kernel, dims, kernel_shape)); - - // go over all nodes that take an index as input (e.g. 
load, store, atomic) - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (node->op->HasAllTypes(OpProp::MemoryOp)) { - if (node->flags.has(NodeProp::LocalMemoryOp)) { - continue; - } - ExecuteExpressionBefore(*node, [&]() { ComputeAddress(node.get(), indices); }); - } - } - } - - //now compute address for all nodes that are not in a kernel - for (auto node = begin(); !node.end(); node.next()) { - if (!node->HasParent("kernel") && node->op->HasAllTypes(OpProp::MemoryOp)) { - ExecuteExpressionBefore(node.get(), [&]() { - Tensors indices = {}; - ComputeAddress(node.get(), indices); - }); - } - } - - for (auto check : dispatch_checks) { - // put the rest of the kernel starting from if_node->next_ as a child - // MoveNodeTo(if_node->node_->child, if_node->node_->next); - Node* if_node = check->node_; - if_node->child = if_node->next; - if_node->next->prev = nullptr; - if_node->next->parent = if_node; - if_node->next = nullptr; - } - - UpdateGraph(); -} - - -void IR::TryReplaceModificationsWithVersions() -{ - UpdateGraph(); - - //get all "set" nodes - vector nodes = GetNodesOfType("set"); - - unordered_set nodes_to_remove; - - for (auto set_node : nodes) { - //look up the memory node - Node* memory_node = set_node->args.Get(ArgType::Memory); - Node* input_value = set_node->args.Get(ArgType::Input); - - //if this node has the same parent as the memory node, then it can be replaced with a version - if (memory_node->parent == set_node->parent) { - //replace the set node with the memory node - ExecuteExpressionBefore(set_node, [&]() { - Tensor& copied = Tensor::copy(*input_value->GetTensor()); - Node* copynode = copied.node_; - memory_node->ReplaceThisWithGivenNode(copynode, set_node->index_, true); - nodes_to_remove.insert(set_node); - }); - } - - UpdateIndex(); - } - - UpdateGraph(); - - // remove all nodes that are not used - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - -void IR::ApplyChanges(bool update_graph, 
const Node* uroot) { - for (auto& [arg, out] : edgesToUpdate) { - Node* from = arg.second; - if (!replacementNodes.contains(from)) { - throw std::runtime_error("No replacement node found for node " + from->name); - } - Node* to = replacementNodes[from]; - out->args.UpdateArgument(arg.first, to); - } - - // remove all nodes that are not used - for (auto* node : removedNodes) { - RemoveNode(node); - } - - if (update_graph) { - UpdateGraph(uroot); - } - - ClearChanges(); -} - -void IR::ClearChanges() { - edgesToUpdate.clear(); - removedNodes.clear(); - replacementNodes.clear(); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Steps/Optimization.cpp b/TensorFrost/Compiler/Steps/Optimization.cpp deleted file mode 100644 index 64c15813..00000000 --- a/TensorFrost/Compiler/Steps/Optimization.cpp +++ /dev/null @@ -1,760 +0,0 @@ -#include "Compiler/KernelGen.h" - -namespace TensorFrost { - -bool isConstantAndEqualTo(const Tensor* tensor, float value) { - if (tensor->node_->name != "const" || tensor->node_->flags.has(NodeProp::Modified)) { - return false; - } - - switch (tensor->node_->format.type) { - case TFType::Float: - return AsFloat(tensor->node_->data[0]) == value; - case TFType::Int: - return AsInt(tensor->node_->data[0]) == value; - case TFType::Uint: - return tensor->node_->data[0] == value; - default: - throw std::runtime_error("Unexpected type in isConstantAndEqualTo"); - } -} - -bool isConstant(const Tensor* tensor) { - return tensor->node_->name == "const" && !tensor->node_->flags.has(NodeProp::Modified); -} - -Tensor* ApplyMultiOP(const Tensor* a, const Tensor* b, std::function opF32, std::function opI32, std::function opU32) { - switch (a->node_->format.type) { - case TFType::Float: - return &Tensor::Constant(opF32(AsFloat(a->node_->data[0]), AsFloat(b->node_->data[0]))); - case TFType::Int: - return &Tensor::Constant(opI32(AsInt(a->node_->data[0]), AsInt(b->node_->data[0]))); - case TFType::Uint: - return 
&Tensor::Constant(opU32(a->node_->data[0], b->node_->data[0])); - default: - throw std::runtime_error("Unexpected type in ApplyMultiOP"); - } -} - -Tensor* ApplyUnaryOP(const Tensor* a, std::function opF32, std::function opI32, std::function opU32) { - switch (a->node_->format.type) { - case TFType::Float: - return &Tensor::Constant(opF32(AsFloat(a->node_->data[0]))); - case TFType::Int: - return &Tensor::Constant(opI32(AsInt(a->node_->data[0]))); - case TFType::Uint: - return &Tensor::Constant(opU32(a->node_->data[0])); - default: - throw std::runtime_error("Unexpected type in ApplyUnaryOP"); - } -} - -#define ApplyOP(v1, v2, op) ApplyMultiOP(v1, v2, [](float a, float b) { return a op b; }, [](int a, int b) { return a op b; }, [](uint a, uint b) { return a op b; }) -#define ApplyFUNC(v1, v2, func) ApplyMultiOP(v1, v2, [](float a, float b) { return func(a, b); }, [](int a, int b) { return func(a, b); }, [](uint a, uint b) { return func(a, b); }) -#define ApplyUOP(v1, op) ApplyUnaryOP(v1, [](float a) { return op a; }, [](int a) { return op a; }, [](uint a) { return op a; }) -#define ApplyUFUNC(v1, func) ApplyUnaryOP(v1, [](float a) { return func(a); }, [](int a) { return func(a); }, [](uint a) { return func(a); } - -void IR::OptimizeOperations() -{ - for (auto node = begin(); !node.end(); node.next()) { - //get node operation - const string op = node->name; - - //get inputs - map inputs = node->args.GetTensors(ArgType::Input); - ExecuteExpressionAfter(*node, [&]() { - const Tensor* result = nullptr; - if (op == "add") { - // if any are zero, replace with the other - if (isConstantAndEqualTo(inputs[0], 0.0F)) { - // replace with input 1 - result = inputs[1]; - } else if (isConstantAndEqualTo(inputs[1], 0.0F)) { - // replace with input 0 - result = inputs[0]; - } - - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1])) { - // replace with result - result = ApplyOP(inputs[0], inputs[1], +); - } - } else if (op == "sub") { - 
// if any are zero, replace with the other - if (isConstantAndEqualTo(inputs[0], 0.0F)) { - // replace with negation of input 1 - result = &(-*inputs[1]); - } else if (isConstantAndEqualTo(inputs[1], 0.0F)) { - // replace with input 0 - result = inputs[0]; - } - - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1])) { - // compute result - result = ApplyOP(inputs[0], inputs[1], -); - } - - //if both are the same node, then replace with zero - if (inputs[0]->node_ == inputs[1]->node_) { - result = &Tensor::Constant(0u, inputs[0]->node_->format); - } - } else if (op == "mul") { - // if any are zero, replace with zero - if (isConstantAndEqualTo(inputs[0], 0.0F) || - isConstantAndEqualTo(inputs[1], 0.0F)) { - // replace with zero - result = &Tensor::Constant(0u, inputs[0]->node_->format); - } - - // if any are one, replace with the other - if (isConstantAndEqualTo(inputs[0], 1.0F)) { - // replace with input 1 - result = inputs[1]; - } else if (isConstantAndEqualTo(inputs[1], 1.0F)) { - // replace with input 0 - result = inputs[0]; - } - - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1])) { - // compute result - result = ApplyOP(inputs[0], inputs[1], *); - } - } else if (op == "div") { - // if first is zero, replace with zero - if (isConstantAndEqualTo(inputs[0], 0.0F)) { - // replace with zero - result = &Tensor::Constant(0u, inputs[0]->node_->format); - } - - // if second is one, replace with first - if (isConstantAndEqualTo(inputs[1], 1.0F)) { - // replace with input 0 - result = inputs[0]; - } - - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1])) { - // compute result - result = ApplyOP(inputs[0], inputs[1], /); - } - } - else if (op == "clamp") { - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1]) && isConstant(inputs[2])) { - // compute result - result = ApplyFUNC(inputs[0], 
inputs[1], max); - result = ApplyFUNC(result, inputs[2], min); - } - } - else if(op == "neg") { - if(isConstant(inputs[0])) { - result = ApplyUnaryOP(inputs[0], [](float a) { return -a; }, [](int a) { return -a; }, [](uint a) { return a; }); - } - } - else if(op == "dim_id") { //if the shape of the dimension is 1 then replace with 0 - int dim = node->data[0]; - const Tensor* shape = node->args.Get(ArgType::Shape, dim)->GetTensor(); - if(isConstantAndEqualTo(shape, 1.0F)) { - result = &Tensor::Constant(0u, TFTypeInt32); - } - } - //TODO (Moroz): add more optimizations - - // if computed optimized result, replace all node references with it - if (result != nullptr) - { - node->ReplaceThisWithGivenNode(result->node_, -1, false, false); - } - }); - } -} - -void IR::OptimizeHostValuesWithHints() -{ - for (auto node = begin(); !node.end(); node.next()) { - //if node inside kernel - skip - if(node->HasParent("kernel")) continue; - - ExecuteExpressionAfter(*node, [&]() { - const Tensor* result = nullptr; - - //if node has a max value hint, then replace it with it - if(node->flags.has(NodeProp::HintMaxValue)) { - int64_t max_value = node->flags.get(NodeProp::HintMaxValue); - result = &Tensor::Constant((uint)max_value, node->format); - } - - if (result != nullptr) - { - for (auto [edge, to] : node->args.OutputsCopy()) { - auto& [id, from] = edge; - //if(to->HasParent("kernel")) continue; #TODO (Moroz): check if this is needed - to->args.UpdateArgument(id, result->node_); - } - } - }); - } - - UpdateGraph(); -} - -void IR::RemoveUnusedOperations() { - unordered_set used_nodes; - //mark all output nodes as used - for (auto node = begin(); !node.end(); node.next()) { - if (node->flags.has(NodeProp::OutputMemory) || - node->flags.has(NodeProp::InputMemory) || - node->op->HasAllTypes(OpProp::Static)) { - used_nodes.insert(node.get()); - } - } - - used_nodes = GetDependencies(used_nodes); - - // remove all nodes that are not used - unordered_set nodes_to_remove; - for (auto node = 
begin(); !node.end(); node.next()) { - if (!used_nodes.contains(node.get())) { - if (!node->flags.has(NodeProp::InputMemory) && !node->flags.has(NodeProp::OutputMemory)) - { - nodes_to_remove.insert(node.get()); - } - } - } - - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - - -void IR::RemoveUnusedKernels() -{ - vector kernels = GetNodesOfType("kernel"); - vector nodes_to_remove; - - for (auto kernel : kernels) { - // remove all kernel nodes that dont do anything - int memory_modifiers = 0; - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (node->op->HasAllTypes(OpProp::Modifier, OpProp::MemoryOp)) { - memory_modifiers++; - } - //if any output is outside the kernel, then the kernel is needed - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (!to->HasParent(kernel)) { - memory_modifiers++; - } - } - } - if (memory_modifiers == 0) nodes_to_remove.push_back(kernel); - } - - // remove all nodes that are not used - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - -#define MAX_UNROLL_NODES 128 -void IR::UnrollLoops(int max_iterations) -{ - vector loops = GetNodesOfType("loop"); - - unordered_set loops_to_remove; - - for (auto loop : loops) { - //get inputs (begin, end, step) - map inputs = loop->args.GetTensors(ArgType::Input); - - //try get the constant values - bool is_const = isConstant(inputs[0]) && isConstant(inputs[1]) && isConstant(inputs[2]); - - bool has_other_loops = loop->HasChild("loop") || loop->HasParent("loop"); - bool has_child_kernel = loop->HasChild("kernel"); - - if (!is_const || has_other_loops) { - continue; - } - - int begin = inputs[0]->TryGetConstant(); - int end = inputs[1]->TryGetConstant(); - int step = inputs[2]->TryGetConstant(); - - //how many iterations to unroll - int iters = (end - begin) / step; - if (iters > max_iterations) { - continue; - } - - //get all children of the loop - vector children = 
GetChildren(loop); - - if (children.size() > MAX_UNROLL_NODES) { -#ifdef _RELWITHDEBINFO - cout << current_pass << ": Warning: Loop has too many children to unroll" << endl; -#endif - continue; - } - - //check if they are not keywords or have no children - set nodes_to_copy; - bool can_unroll = true; - for (auto child : children) { - if (child->op->class_ == OpClass::Keyword || child->child->valid()) { - can_unroll = false; - break; - } - nodes_to_copy.insert(child); - } - - if (!can_unroll) { - continue; - } - - //unroll the loop - ExecuteExpressionAfter(loop, [&]() { - for (int i = begin; i < end; i += step) { - unordered_map arg_remap; - Tensor* index = &Tensor::Constant(i); - //index->SetDebugName(loop->debug_name + "_unroll_" + to_string(i)); - arg_remap[loop] = index->node_; - CopyNodes(nodes_to_copy, arg_remap, {}, {}, false); - } - }); - - //mark the loop for removal - loops_to_remove.insert(loop); - } - - // remove all loops that are not used - for (auto* loop : loops_to_remove) { - RemoveNode(loop); - } - - UpdateGraph(); -} - -void IR::UnrollAtomicOperations() { - vector atomics = GetNodesOfType(OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier); - - vector nodes_to_remove; - for (auto node: atomics) { - if(node->flags.has(NodeProp::LocalMemoryOp)) continue; - - std::set unused_dimensions; - Node* next_node = node->next; - int dim = node->args.Count(ArgType::Shape); - for (int i = 0; i < dim; i++) { - unused_dimensions.insert(i); - } - - //get the indices of the scatter operation - NodeArguments indices = node->args.GetArguments(ArgType::Index); - //get dependencies of all indices - unordered_set index_nodes; - for (auto& [id, index] : indices) { - index_nodes.insert(index); - } - unordered_set dependencies = GetDependencies(index_nodes); - //if any of the dependencies are a dim_id node, then its dimension(data[0]) is used - for (auto dep : dependencies) { - if (dep->name == "dim_id") { - unused_dimensions.erase(dep->data[0]); - } - } - - int 
unused_count = (int)unused_dimensions.size(); - - if (unused_count == 0) { - continue; - } - - auto old_shape = node->args.GetTensors(ArgType::Shape); - - //reduce the dimensions of the scatter operation - const Tensor* current_reduce = node->args.Get(ArgType::Input, 0)->GetTensor(); - vector unused_dims_vec(unused_dimensions.begin(), unused_dimensions.end()); - //sort the dimensions in descending order - std::reverse(unused_dims_vec.begin(), unused_dims_vec.end()); - bool supported_operation = true; - - BeginScope(next_node); - - for(int udim : unused_dims_vec) { - if(node->name == "InterlockedAdd") { - current_reduce = &Tensor::Sum(*current_reduce, udim); - } else if (node->name == "InterlockedMin") { - current_reduce = &Tensor::Min(*current_reduce, udim); - } else if (node->name == "InterlockedMax") { - current_reduce = &Tensor::Max(*current_reduce, udim); - } else if (node->name == "InterlockedAnd") { - current_reduce = &Tensor::All(*current_reduce, udim); - } else if (node->name == "InterlockedOr") { - current_reduce = &Tensor::Any(*current_reduce, udim); - } else { - supported_operation = false; - break; - } - } - - if (!supported_operation) { - EndScope(); -#ifdef _RELWITHDEBINFO - cout << current_pass << ": Warning: Unsupported atomic operation " << node->name << endl; -#endif - continue; - } - - unordered_map old_to_new_dim; - int new_idx = 0; - - for (int i = 0; i < dim; i++) { - if (!unused_dimensions.contains(i)) { - old_to_new_dim[i] = current_reduce->Index(new_idx++).node_; - } else { - old_to_new_dim[i] = Tensor::Constant(current_reduce->GetShape(), 0, current_reduce->GetFormat()).node_; - } - } - - map copied_node_map = CopyNodesWithIndex(index_nodes, old_to_new_dim); - Tensors new_indices = Tensors(); - new_indices.resize(indices.size()); - for (auto& [id, index] : indices) { - new_indices[id.second] = copied_node_map[index]->GetTensor(); - } - - //get the memory to scatter to - const Tensor* memory = node->args.Get(ArgType::Memory)->GetTensor(); - 
Tensor* store_op = nullptr; - Tensor* old_value = nullptr; - if(false) { //TODO (Moroz): check if indexes are one-to-one, otherwise must use atomic operations - old_value = &Tensor::Load(*memory, new_indices); - Tensor* new_value = nullptr; - if(node->name == "InterlockedAdd") { - new_value = &(*old_value + *current_reduce); - } else if (node->name == "InterlockedMin") { - new_value = &Tensor::min(*old_value, *current_reduce); - } else if (node->name == "InterlockedMax") { - new_value = &Tensor::max(*old_value, *current_reduce); - } else if (node->name == "InterlockedAnd") { - new_value = &(*old_value & *current_reduce); - } else if (node->name == "InterlockedOr") { - new_value = &(*old_value | *current_reduce); - } - store_op = &Tensor::Store(*memory, *new_value, new_indices); - } else { //still use atomic operations - store_op = &Tensor::MemoryOp(node->name, memory, new_indices, current_reduce); - } - - nodes_to_remove.push_back(node); - - EndScope(); - } - - for (auto node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - -#define MIN_SPLIT_SIZE 1024 -#define SPLIT_DIM_SIZE 128 - -void IR::OptimizeReductions() { - vector reductions = GetNodesOfType(OpProp::Algorithm, OpProp::Reduction); - - vector nodes_to_remove; - for (auto node : reductions) { - int axis = (int)node->data[0]; - //get the input tensor - const Tensor* input = node->args.Get(ArgType::Input, 0)->GetTensor(); - //get the shape of the tensor at the reduction axis - const Tensor* tensor = input->node_->args.Get(ArgType::Shape, axis)->GetTensor(); - //try to get the constant value - int axis_value = tensor->TryGetConstant(); - if (axis_value < 0) { -#ifdef _RELWITHDEBINFO - cout << current_pass << ": Warning: Can not apply reduction optimization on non-constant axis" << endl; -#endif - continue; - } - //if size of the axis is less than the minimum split size, then do not split - if (axis_value < MIN_SPLIT_SIZE) { - continue; - } - - ExecuteExpressionAfter(node, [&]() { - //split 
dimension into smaller chunks and reduce them sequentially - const Tensor* split = &Tensor::SplitDim(*input, SPLIT_DIM_SIZE, axis); - Tensor* result = &Tensor::ReductionOP(node->name, *split, axis, false); - result = &Tensor::ReductionOP(node->name, *result, axis, false); - node->ReplaceThisWithGivenNode(result->node_); - nodes_to_remove.push_back(node); - }); - } - - for (auto node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - -#define MAX_KERNEL_COPY_COST 50000.0f -bool IR::OptimizeKernels() { - // get kernel data - vector kernels = GetNodesOfType("kernel"); - ComputeNodeCost(); - - bool changed = false; - // go over each kernel and copy computations outside the kernel if they are - // cheap enough - for (auto kernel : kernels) { - ArgEdges args_to_copy; - ArgEdges shape_args_to_copy; - // go over all nodes in the kernel and check if their inputs can be copied - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - // go over all inputs - for (auto& [arg, from]: node->args.Inputs()) { - bool inside_kernel = from->HasParent(kernel); - bool from_in_kernel = from->HasParent("kernel"); - - if(from->flags.has(NodeProp::NoCopyFusion)) continue; - - if (!inside_kernel && !node->args.CannotCopyArgument(arg)) - { - // check if input is cheap enough to copy - float input_cost = from->cost_; - if (input_cost == -1.0) { -#ifdef _RELWITHDEBINFO - cout << current_pass << ": Warning: Could not determine cost of node " << from->name << endl; -#endif - continue; - } - bool cheap_enough = input_cost >= 0.0f && input_cost < MAX_KERNEL_COPY_COST; - bool has_only_one_output = from->args.OutputCount() == 1; - if (cheap_enough || has_only_one_output) { - args_to_copy.insert(ArgEdge(Arg(arg, from), *node)); - } - } - //shape arguments can not be inside kernels - if (from_in_kernel && arg.first == ArgType::Shape) { - shape_args_to_copy.insert(ArgEdge(Arg(arg, from), *node)); - } - } - } - -// auto kernel_deps_old = ComputeKernelDependencies(kernel); 
- - //go over kernel shape arguments - for (int i = 0; i < kernel->args.Count(ArgType::Shape); i++) { - Node* shape_node = kernel->args.Get(ArgType::Shape, i); - bool from_in_kernel =shape_node->HasParent("kernel"); - if (from_in_kernel) { - shape_args_to_copy.insert(ArgEdge(Arg(ArgID(ArgType::Shape, i), shape_node), kernel)); - } - } - - //copy the nodes that are outside the kernel inside - CopyArguments(args_to_copy, kernel->child); - -// auto kernel_deps = ComputeKernelDependencies(kernel); - -// //if more than allowed then do not apply changes -// if (kernel_deps.size() > max_kernel_memory_dependencies && kernel_deps_old.size() < kernel_deps.size() && !must_copy_all) { -// ClearChanges(); -// #ifdef _RELWITHDEBINFO -// std::cout << current_pass << ": Warning: Discarding kernel optimization changes for kernel " << kernel->name << " with " << kernel_deps.size() << " dependencies while before it had " << kernel_deps_old.size() << " dependencies" << endl; -// #endif -// } - - //copy shape arguments before the kernel - CopyArguments(shape_args_to_copy, kernel); - if (!args_to_copy.empty()) { - changed = true; - } - - ApplyChanges(false); - } - - return !changed; -} - -#define MAX_LOAD_COPY 3000.0f -#define MAX_LOAD_COPY_COUNT 2 -#define MAX_LOAD_SIZE_RATIO 0.5f - -bool IR::OptimizeKernelLoadOperations() { - ComputeNodeCost(); - - vector kernels = GetNodesOfType("kernel"); - - unordered_set nodes_to_remove; - - size_t loads_fused = 0; - - for (auto kernel : kernels) { - ShapeInfo kernel_shape = ShapeInfo(kernel); - - unordered_map loads_to_copy; - unordered_set memory_inputs; - // go over all nodes in the kernel and check if their inputs can be copied - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (node->name != "load") continue; - - if (node->flags.has(NodeProp::NoLoadFusion)) continue; - - //get memory input - Node* memory_input = node->args.Get(ArgType::Memory); - - if(memory_input->flags.has(NodeProp::StopFusion)) continue; - - 
ShapeInfo memory_shape = ShapeInfo(memory_input); - - bool inside_kernel = memory_input->HasParent("kernel"); - if (!inside_kernel) continue; - - bool is_not_modified = !memory_input->flags.has(NodeProp::Modified); - if (!is_not_modified) continue; - - float kernel_size = ShapeInfo::GetSizeEstimate(kernel_shape); - float memory_size = ShapeInfo::GetSizeEstimate(memory_shape); - float size_ratio = kernel_size / memory_size; - - int output_count = (int)memory_input->args.OutputCount(); - //only fuse if this is used less than MAX_LOAD_COPY_COUNT times or we can reduce dimensionality by fusing - bool fusion_makes_sense = (output_count < MAX_LOAD_COPY_COUNT) || - (size_ratio <= MAX_LOAD_SIZE_RATIO) || memory_size == 1.0f; - bool cheap_enough = memory_input->cost_ >= 0.0f && - memory_input->cost_ < (MAX_LOAD_COPY / output_count); - - - //if the memory input is used only once and is not a memory node - if (cheap_enough && fusion_makes_sense) { - loads_to_copy[memory_input] = *node; - memory_inputs.insert(memory_input); - } - } - - if (loads_to_copy.empty()) continue; - - //auto kernel_deps_old = ComputeKernelDependencies(kernel); - - for (auto load : loads_to_copy) { - //get the load - Node* memory_input = load.first; - Node* load_node = load.second; - - //get the indices - unordered_map indices; - for (auto& [arg, from] : load_node->args.Inputs()) { - if (arg.first == ArgType::Index) { - indices[arg.second] = from; - } - } - - //copy the load node - map copied_node_map = CopyNodesWithIndex({ memory_input }, indices, load_node); - - - Tensors indices_tensors = Tensors(); - indices_tensors.resize(indices.size()); - for (auto& [index, node] : indices) { - indices_tensors[index] = node->GetTensor(); - } - - //go over all the copied nodes and add load nodes to their inputs that are outside the kernel - for (auto& [old_node, new_node] : copied_node_map) { - AddNodeLoadOperations(new_node, kernel, indices_tensors); - } - - Node* copied_load = copied_node_map[memory_input]; - 
//copy over the information from the original load node - copied_load->CopyMetadata(load_node); - - map replacements; replacements[load_node] = copied_load; - - ReplaceArgs(load_node->args.Outputs(), replacements); - } - - //auto kernel_deps = ComputeKernelDependencies(kernel); - -// //if more than allowed then do not apply changes -// if (kernel_deps.size() > max_kernel_memory_dependencies && kernel_deps_old.size() < kernel_deps.size()) { -// ClearChanges(); -// #ifdef _RELWITHDEBINFO -// std::cout << current_pass << ": Warning: Discarding kernel load fusion changes for kernel " << kernel->name << " with " << kernel_deps.size() << " dependencies, before it had " << kernel_deps_old.size() << " dependencies" << endl; -// #endif -// continue; -// } - - //remove the load node since it is not needed anymore - for (auto load : loads_to_copy) { - nodes_to_remove.insert(load.second); - } - - loads_fused += loads_to_copy.size(); - ApplyChanges(false); - } - - // remove the load nodes - for (auto node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); - - return loads_fused == 0; -} - - - -#define MAX_HOST_COPY_COST 8192.0f - -void IR::OptimizeHost() { - ComputeNodeCost(); - - //loop over all nodes and copy their arguments if they are cheap enough and inside kernels - for (auto node = begin(); !node.end(); node.next()) { - if (node->HasParent("kernel")) { - continue; - } - - ArgEdges args_to_copy; - // go over all inputs - for (auto& [arg, from] : node->args.Inputs()) { - bool inside_kernel = from->HasParent("kernel"); - - if (inside_kernel && !node->args.CannotCopyArgument(arg)) { - // check if input is cheap enough to copy - float input_cost = from->cost_; - if (input_cost == -1.0) { - //throw std::runtime_error("Cost has not been computed for node " + input.from_->get()->var_name); - continue; - } - bool cheap_enough = input_cost >= 0.0f && input_cost < MAX_HOST_COPY_COST; - bool has_only_one_output = from->args.OutputCount() == 1; - - if (cheap_enough || 
has_only_one_output) { - args_to_copy.insert(ArgEdge(Arg(arg, from), *node)); - } else { - throw std::runtime_error("Host optimization: Copy cost too high for node " + node->name + " with cost " + to_string(input_cost)); - } - } - } - - CopyArguments(args_to_copy, node.get()); - ApplyChanges(false); - } -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/PyModule.cpp b/TensorFrost/Frontend/Python/Definitions/PyModule.cpp deleted file mode 100644 index 219bda10..00000000 --- a/TensorFrost/Frontend/Python/Definitions/PyModule.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include -#include - -#include -#include - -namespace TensorFrost { - -class PyModule : public Module { -public: - using Module::Module; // Inherit constructors - - void assert_parameters() override { - PYBIND11_OVERRIDE(void, Module, assert_parameters); - } - - py::object loss(py::object X, py::object Y) override { - PYBIND11_OVERRIDE_PURE(py::object, Module, loss, X, Y); - } - - py::object forward(py::object X) override { - PYBIND11_OVERRIDE_PURE(py::object, Module, forward, X); - } -}; - -void ModuleDefinitions(py::module& m) { - py::class_(m, "Parameter") - .def(py::init&, TFDataFormat, float, float, bool>(), py::arg("shape"), py::arg("dtype") = TFType::Float, py::arg("random_scale") = -1.0f, py::arg("random_offset") = 0.0f, py::arg("optimize") = true) - .def_readwrite("shape", &Parameter::shape) - .def_readwrite("dtype", &Parameter::dtype) - .def_readwrite("random_scale", &Parameter::random_scale) - .def_readwrite("random_offset", &Parameter::random_offset) - .def("__repr__", [](const Parameter& p) { - return "Parameter(shape=" + std::to_string(p.shape.size()) + ", dtype=" + std::to_string(p.dtype.type) + "( " + std::to_string(p.dtype.size) + ") , random_scale=" + std::to_string(p.random_scale) + ", random_offset=" + std::to_string(p.random_offset) + ", optimize=" + std::to_string(p.optimize) + ")"; - }); - - py::class_(m, "ParameterArray") - 
.def(py::init<>()) - .def("__getitem__", &ParameterArray::getitem) - .def("__setitem__", &ParameterArray::setitem) - .def("items", &ParameterArray::items); - - py::class_(m, "Module") - .def(py::init(), py::arg("requires_grad") = true) - .def("__getattr__", &Module::getattr) - .def("__setattr__", &Module::setattr) - .def("hasattr", &Module::hasattr) - .def("param_requires_grad", &Module::param_requires_grad) - .def("initialize_input", &Module::initialize_input) - .def("initialize_parameters", &Module::initialize_parameters) - .def("initialize_parameters_native", &Module::initialize_parameters_native) - .def("parameters", &Module::parameters) - .def("named_parameters", &Module::named_parameters) - .def("requires_grads_list", &Module::requires_grads_list) - .def("create_input", &Module::create_input) - .def("update_parameters", &Module::update_parameters) - .def("assert_parameters", &Module::assert_parameters) - .def("loss", &Module::loss) - .def("forward", &Module::forward); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/PyTensor.cpp b/TensorFrost/Frontend/Python/Definitions/PyTensor.cpp deleted file mode 100644 index 750b8c02..00000000 --- a/TensorFrost/Frontend/Python/Definitions/PyTensor.cpp +++ /dev/null @@ -1,199 +0,0 @@ -#include -#include - -#include - -namespace TensorFrost { - -void DefineOperator( - const std::string& pyname, py::class_& py_tensor, - const std::function& op) { - py_tensor.def(l_op(pyname).c_str(), - [op](const PyTensor& t, const PyTensor& t2) { - return PT(op(T(t), T(t2))); - }); - py_tensor.def(l_op(pyname).c_str(), [op](const PyTensor& t, const float f) { - return PT(op(T(t), Tensor::Constant(f))); - }); - py_tensor.def(l_op(pyname).c_str(), [op](const PyTensor& t, const int i) { - return PT(op(T(t), Tensor::Constant(i))); - }); - py_tensor.def(r_op(pyname).c_str(), [op](const PyTensor& t, const float f) { - return PT(op(Tensor::Constant(f), T(t))); - }); - 
py_tensor.def(r_op(pyname).c_str(), [op](const PyTensor& t, const int i) { - return PT(op(Tensor::Constant(i), T(t))); - }); -} - -#define LAMBDA_OP(op) \ - [](const Tensor& t1, const Tensor& t2) -> Tensor& { return t1 op t2; } - -void DefineOperators(py::class_& py_tensor) { - DefineOperator("add", py_tensor, LAMBDA_OP(+)); - DefineOperator("sub", py_tensor, LAMBDA_OP(-)); - DefineOperator("mul", py_tensor, LAMBDA_OP(*)); - DefineOperator("div", py_tensor, LAMBDA_OP(/)); - DefineOperator("truediv", py_tensor, LAMBDA_OP(/)); - DefineOperator("mod", py_tensor, LAMBDA_OP(%)); - DefineOperator("eq", py_tensor, LAMBDA_OP(==)); - DefineOperator("ne", py_tensor, LAMBDA_OP(!=)); - DefineOperator("lt", py_tensor, LAMBDA_OP(<)); - DefineOperator("le", py_tensor, LAMBDA_OP(<=)); - DefineOperator("gt", py_tensor, LAMBDA_OP(>)); - DefineOperator("ge", py_tensor, LAMBDA_OP(>=)); - DefineOperator("and", py_tensor, LAMBDA_OP(&&)); - DefineOperator("or", py_tensor, LAMBDA_OP(||)); - DefineOperator("xor", py_tensor, LAMBDA_OP(^)); - DefineOperator("lshift", py_tensor, LAMBDA_OP(<<)); - DefineOperator("rshift", py_tensor, LAMBDA_OP(>>)); - DefineOperator("and_", py_tensor, LAMBDA_OP(&)); - DefineOperator("or_", py_tensor, LAMBDA_OP(|)); - - py_tensor.def("__neg__", [](const PyTensor& t) { return PT(-T(t)); }); - py_tensor.def("__not__", [](const PyTensor& t) { return PT(!T(t)); }); - py_tensor.def("__invert__", [](const PyTensor& t) { return PT(~T(t)); }); - py_tensor.def("__pow__", [](const PyTensor& t, const PyTensor& t2) { - return PT(Tensor::pow(T(t), T(t2))); - }); - py_tensor.def("__pow__", [](const PyTensor& t, float f) { - return PT(Tensor::pow(T(t), Tensor::Constant(f))); - }); - py_tensor.def("__rpow__", [](const PyTensor& t, float f) { - return PT(Tensor::pow(Tensor::Constant(f), T(t))); - }); - py_tensor.def("__matmul__", [](const PyTensor& t, const PyTensor& t2) { - return PT(Tensor::Matmul(T(t), T(t2))); - }); -} - -void PyTensorDefinition(py::module& /*m*/, 
py::class_& py_tensor) { - // initializers - py_tensor.def(py::init()); - py_tensor.def(py::init()); - py_tensor.def(py::init()); - py_tensor.def(py::init()); - - // properties - py_tensor.def_property_readonly("shape", [](const PyTensor& t) { - return PyTensorsFromTensors(Reverse(t.Get().GetShape())); - }); - py_tensor.def_property_readonly( - "type", [](const PyTensor& t) { return t.Get().GetFormat(); }); - py_tensor.def_property_readonly("indices", [](const PyTensor& t) { - int dim = T(t).GetDimension(); - py::tuple indices(dim); - for (int i = 0; i < dim; i++) { - indices[i] = PT(T(t).Index(dim - i - 1)); - } - return indices; - }); - py_tensor.def_property_readonly("op_name", [](const PyTensor& t) { - return T(t).node_->name; - }); - py_tensor.def("try_get_constant", [](const PyTensor& t) { - if(T(t).node_->name != "const") { - throw std::runtime_error("Can not get constant from non-constant tensor"); - } - return T(t).TryGetConstant(); - }); - py_tensor.def("index",[](const PyTensor& t, int dim) { - int dims = T(t).GetDimension(); - return PT(T(t).Index(dims - dim - 1)); - }); - - py_tensor.def("block_index", [](const PyTensor& t) { - return PT(T(t).BlockIndex()); - }); - - py_tensor.def("block_thread_index", [](const PyTensor& t, int block_dim) { - return PT(T(t).BlockThreadIndex(block_dim)); - }); - - py_tensor.def("detach_grad", [](const PyTensor& t) { - t.Get().DetachGrad(); - return t; - }); - - py_tensor.def("pass_grad", [](const PyTensor& t) { - t.Get().PassGrad(); - return t; - }); - - py_tensor.def("stop_fusion", [](const PyTensor& t) { - t.Get().StopFusion(); - return t; - }); - - py_tensor.def("hint_range", [](const PyTensor& t, py::object min, py::object max) { - if(t.Get().node_->format == TFTypeFloat32) { - t.Get().HintRange(py::cast(min), py::cast(max)); - } else { - t.Get().HintRange(py::cast(min), py::cast(max)); - } - }, py::arg("min"), py::arg("max")); - - // operators - DefineOperators(py_tensor); - - //no way to overload normal setter - 
//TODO use python AST to generate these functions - py_tensor.def("set", - [](const PyTensor& t, const PyTensor& t2) { T(t).Set(T(t2)); }); - - py_tensor.def_property("val", [](const PyTensor& t) { return t; }, - [](PyTensor& t, const PyTensor& val) { T(t).Set(T(val)); }); - - // indexing - py_tensor.def("__getitem__", [](const PyTensor& t, const PyTensor& t1) { - Tensors indices; - indices.push_back(&t1.Get()); - return PyTensor(&t.Get(), indices); - }); - py_tensor.def("__getitem__", [](const PyTensor& t, py::tuple indices_tuple) { - Tensors indices = Reverse(TensorsFromTuple(indices_tuple)); - return PyTensor(&t.Get(), indices); - }); - - py_tensor.def("__setitem__", - [](const PyTensor& t, const PyTensor& t1, const PyTensor& t2) { - Tensors indices; - indices.push_back(&t1.Get()); - Tensor::Store(t.Get(), T(t2), indices); - }); - py_tensor.def("__setitem__", [](const PyTensor& t, py::tuple indices_tuple, - const PyTensor& t2) { - Tensors indices = Reverse(TensorsFromTuple(indices_tuple)); - Tensor::Store(t.Get(), T(t2), indices); - }); - - py_tensor.def("__setitem__", [](const PyTensor& t, const PyTensor& t1, pybind11::none none) { - //do nothing - }); - py_tensor.def("__setitem__", [](const PyTensor& t, py::tuple indices_tuple, pybind11::none none) { - //do nothing - }); - - // transpose - py_tensor.def("transpose", [](const PyTensor& t, int dim1, int dim2) { - return PT(Tensor::Transpose(T(t), -dim1-1, -dim2-1)); - }, py::arg("dim1") = -2, py::arg("dim2") = -1, "Transpose the tensor"); - - //transpose property - py_tensor.def_property_readonly("T", [](const PyTensor& t) { - return PT(Tensor::Transpose(T(t))); - }); - - py_tensor.def("__str__", [](const PyTensor& t) { - return GetNodeString(t.Get().node_); - }); - py_tensor.def("__repr__", [](const PyTensor& t) { - return GetNodeString(t.Get().node_); - }); - - py_tensor.def("set_debug_name", [](const PyTensor& t, const std::string& name) { - t.Get().SetDebugName(name); - }); -} - -} // namespace TensorFrost \ 
No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/TensorFunctions.cpp b/TensorFrost/Frontend/Python/Definitions/TensorFunctions.cpp deleted file mode 100644 index 70637f84..00000000 --- a/TensorFrost/Frontend/Python/Definitions/TensorFunctions.cpp +++ /dev/null @@ -1,353 +0,0 @@ -#include -#include - -#include - -namespace TensorFrost { - -#define UNARY_FUNCTION(name) \ - m.def(#name, [](const PyTensor& t) { return PT(Tensor::name(T(t))); }) - -#define BINARY_FUNCTION(name) \ - m.def(#name, [](const PyTensor& t, const PyTensor& t2) { \ - return PT(Tensor::name(T(t), T(t2))); \ - }) - -#define TERNARY_FUNCTION(name) \ - m.def(#name, [](const PyTensor& t, const PyTensor& t2, const PyTensor& t3) { \ - return PT(Tensor::name(T(t), T(t2), T(t3))); \ - }) - -void TensorFunctionsDefinition(py::module& m) { - UNARY_FUNCTION(copy); - UNARY_FUNCTION(abs); - UNARY_FUNCTION(ceil); - UNARY_FUNCTION(floor); - UNARY_FUNCTION(round); - UNARY_FUNCTION(trunc); - UNARY_FUNCTION(sign); - UNARY_FUNCTION(frac); - UNARY_FUNCTION(sin); - UNARY_FUNCTION(cos); - UNARY_FUNCTION(tan); - UNARY_FUNCTION(asin); - UNARY_FUNCTION(acos); - UNARY_FUNCTION(atan); - UNARY_FUNCTION(sinh); - UNARY_FUNCTION(cosh); - UNARY_FUNCTION(tanh); - UNARY_FUNCTION(exp); - UNARY_FUNCTION(exp2); - UNARY_FUNCTION(log); - UNARY_FUNCTION(log2); - UNARY_FUNCTION(sqrt); - UNARY_FUNCTION(sqr); - UNARY_FUNCTION(rsqrt); - UNARY_FUNCTION(rcp); - - UNARY_FUNCTION(pcg); - UNARY_FUNCTION(pcgf); - UNARY_FUNCTION(reversebits); - - m.def("float", [](const PyTensor& t) { return PT(Tensor::tofloat(T(t))); }); - m.def("uint", [](const PyTensor& t) { return PT(Tensor::touint(T(t))); }); - m.def("int", [](const PyTensor& t) { return PT(Tensor::toint(T(t))); }); - m.def("bool", [](const PyTensor& t) { return PT(Tensor::tobool(T(t))); }); - - m.def("asfloat", [](const PyTensor& t) { return PT(Tensor::asfloat(T(t))); }); - m.def("asuint", [](const PyTensor& t) { return PT(Tensor::asuint(T(t))); }); - 
m.def("asint", [](const PyTensor& t) { return PT(Tensor::asint(T(t))); }); - - BINARY_FUNCTION(min); - BINARY_FUNCTION(max); - BINARY_FUNCTION(pow); - BINARY_FUNCTION(atan2); - BINARY_FUNCTION(modf); - - BINARY_FUNCTION(grad); - - TERNARY_FUNCTION(clamp); - TERNARY_FUNCTION(fma); - TERNARY_FUNCTION(lerp); - TERNARY_FUNCTION(select); - TERNARY_FUNCTION(smoothstep); - - m.def("scatterAdd", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterAdd(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterAddPrev", [](const PyTensor& t, const PyTensor& t2) { - return PT(Tensor::ScatterAddPrev(*t.Value(), T(t2), t.Indices())); - }); - - m.def("scatterMin", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterMin(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterMax", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterMax(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterOr", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterOr(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterAnd", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterAnd(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterXor", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterXor(*t.Value(), T(t2), t.Indices()); - }); - - m.def("buffer", [](py::list shape, TFDataFormat type) { - return PT(Tensor::Memory(Reverse(TensorsFromList(shape)), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - m.def("buffer", [](std::vector shape, TFDataFormat type) { - return PT(Tensor::Memory(Reverse(shape), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - - m.def("local_buffer", [](int size, TFDataFormat type) { - return PT(Tensor::LocalMemory(size, type)); - }, py::arg("size"), py::arg("type") = TFTypeFloat32); - m.def("group_buffer", [](int size, TFDataFormat type) { - return PT(Tensor::GroupMemory(size, type)); - }, py::arg("size"), py::arg("type") = TFTypeFloat32); - m.def("group_barrier", []() { - 
Tensor::GroupBarrier(); - }); - - m.def("zeros", [](py::list shape, TFDataFormat type) { - return PT(Tensor::Constant(0u, Reverse(TensorsFromList(shape)), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - - m.def("const", [](float value, py::list shape) { - return PT(Tensor::Constant(Reverse(TensorsFromList(shape)), value)); - }); - m.def("const", [](float value, std::vector shape) { - return PT(Tensor::Constant(Reverse(shape), value)); - }, py::arg("value"), py::arg("shape") = std::vector{}); - - m.def("const", [](int value, py::list shape) { - return PT(Tensor::Constant(Reverse(TensorsFromList(shape)), value)); - }); - - m.def("const", [](int value, std::vector shape) { - return PT(Tensor::Constant(Reverse(shape), value)); - }, py::arg("value"), py::arg("shape") = std::vector{}); - - m.def("input", [](std::vector shape, TFDataFormat type) { - return PT(Tensor::Input(Reverse(shape), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - - m.def("input", [](py::list shape, TFDataFormat type) { - return PT(Tensor::Input(Reverse(TensorsFromList(shape)), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - - m.def("index", [](int dim, py::list shape) { - return PT(Tensor::Index(Reverse(TensorsFromList(shape)), dim)); - }); - - m.def("hash", [](py::list shape, const PyTensor& seed) { - return PT(Tensor::Hash(Reverse(TensorsFromList(shape)), T(seed))); - }, py::arg("shape"), py::arg("seed")); - - m.def("random_value", [](py::list shape, const PyTensor& seed) { - return PT(Tensor::Random(Reverse(TensorsFromList(shape)), T(seed))); - }, py::arg("shape"), py::arg("seed")); - - m.def("element_index", [](py::list shape) { - return PT(Tensor::ElementIndex(Reverse(TensorsFromList(shape)))); - }, py::arg("shape")); - - m.def("flat_index", [](py::list shape, py::list indices) { - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - Tensors index_tensors = Reverse(TensorsFromList(indices)); - return 
PT(Tensor::FlatIndex(shape_tensors, index_tensors)); - }); - - m.def("indices_from_flat_index", [](const PyTensor& index, py::list shape) { - py::tuple indices = py::tuple(shape.size()); - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - Tensors indices_tensors = Reverse(Tensor::IndicesFromFlatIndex(&T(index), shape_tensors)); - for (int i = 0; i < indices_tensors.size(); i++) { - indices[i] = PT(*indices_tensors[i]); - } - return indices; - }); - - m.def("get_copy", [](const PyTensor& t) { return PT(*Tensor::GetCopy(T(t))); }); - - m.def("indices", [](py::list shape) { - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - int dim = (int)shape_tensors.size(); - py::tuple indices = py::tuple(shape_tensors.size()); - for (int i = 0; i < shape_tensors.size(); i++) { - auto t = PT(Tensor::Index(shape_tensors, dim - i - 1)); - indices[i] = t; - } - return indices; - }); - - m.def("indices", [](std::vector shape) { - py::tuple indices = py::tuple(shape.size()); - int dim = (int)shape.size(); - for (int i = 0; i < shape.size(); i++) { - auto t = PT(Tensor::Index(Reverse(shape), dim - i - 1)); - indices[i] = t; - } - return indices; - }); - - - m.def("index_grid", [](py::list begin, py::list end) { - Tensors begin_tensors = Reverse(TensorsFromList(begin)); - Tensors end_tensors = Reverse(TensorsFromList(end)); - Tensors index_grid = Reverse(Tensor::IndexGrid(begin_tensors, end_tensors)); - - py::tuple indices = py::tuple(begin.size()); - for (int i = 0; i < index_grid.size(); i++) { - indices[i] = PT(*index_grid[i]); - } - return indices; - }); - - m.def("index_grid", [](py::list begin, py::list end, py::list step) { - Tensors begin_tensors = Reverse(TensorsFromList(begin)); - Tensors end_tensors = Reverse(TensorsFromList(end)); - Tensors step_tensors = Reverse(TensorsFromList(step)); - Tensors index_grid = Reverse(Tensor::IndexGrid(begin_tensors, end_tensors, step_tensors)); - - py::tuple indices = py::tuple(begin.size()); - for (int i = 0; i < 
index_grid.size(); i++) { - indices[i] = PT(*index_grid[i]); - } - return indices; - }); - - m.def("reshape", [](const PyTensor& t, py::list shape, TFDataFormat type) { - return PT(Tensor::Reshape(T(t), Reverse(TensorsFromList(shape)), type)); - }, py::arg("t"), py::arg("shape"), py::arg("type") = TFTypeNone); - - m.def("assert_tensor", [](const PyTensor& t, py::list target_shape, TFDataFormat target_type) { - return PT(Tensor::Assert(T(t), Reverse(TensorsFromList(target_shape)), target_type)); - }); - m.def("split_dim", [](const PyTensor& t, const int split_size, const int axis) { - return PT(Tensor::SplitDim(T(t), split_size, -axis-1)); - }, py::arg("t"), py::arg("split_size"), py::arg("axis") = -1); - m.def("merge_dim", [](const PyTensor& t, const int axis, const PyTensor* target_size) { - const Tensor* target_size_ptr = target_size ? &T(*target_size) : nullptr; - return PT(Tensor::MergeDim(T(t), -axis-1, target_size_ptr)); - }, py::arg("t"), py::arg("axis") = -1, py::arg("target_size") = nullptr); - m.def("repeat", [](const PyTensor& t, const PyTensor& repeats, const int axis) { - return PT(Tensor::Repeat(T(t), T(repeats), -axis-1)); - }, py::arg("t"), py::arg("repeats"), py::arg("axis") = -1); - - //algorithm functions - m.def("sum", [](const PyTensor& t, const int axis) { return PT(Tensor::Sum(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Sum the elements of the tensor along the axis"); - - m.def("norm", [](const PyTensor& t, const int axis) { return PT(Tensor::Norm(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the norm of the tensor along the axis"); - - m.def("mean", [](const PyTensor& t, const int axis) { return PT(Tensor::Mean(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the mean of the tensor along the axis"); - - m.def("min", [](const PyTensor& t, const int axis) { return PT(Tensor::Min(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = 
-1, "Compute the min of the tensor along the axis"); - - m.def("max", [](const PyTensor& t, const int axis) { return PT(Tensor::Max(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the max of the tensor along the axis"); - - m.def("any", [](const PyTensor& t, const int axis) { return PT(Tensor::Any(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Do an OR operation along the axis"); - - m.def("all", [](const PyTensor& t, const int axis) { return PT(Tensor::All(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Do an AND operation along the axis"); - - m.def("prefix_sum", [](const PyTensor& t, const int axis) { return PT(Tensor::PrefixSum(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the prefix sum of the tensor along the axis"); - - m.def("reverse", [](const PyTensor& t, const int axis) { return PT(Tensor::Reverse(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Reverse the tensor along the axis"); - - m.def("transpose", [](const PyTensor& t, int dim1, int dim2) { - return PT(Tensor::Transpose(T(t), -dim1-1, -dim2-1)); - }, py::arg("t"), py::kw_only(), py::arg("dim1") = -2, py::arg("dim2") = -1, "Transpose the tensor"); - - m.def("unsqueeze", [](const PyTensor& t, int dim) { - return PT(Tensor::Unsqueeze(T(t), -dim-1)); - }, py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Unsqueeze the tensor"); - - m.def("squeeze", [](const PyTensor& t, int dim) { - return PT(Tensor::Squeeze(T(t), -dim-1)); - }, py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Squeeze the tensor"); - - m.def("dot", [](const PyTensor& t, const PyTensor& t2, int axis) { - return PT(Tensor::Dot(T(t), T(t2), -axis-1)); - }, py::arg("t"), py::arg("t2"), py::kw_only(), py::arg("axis") = -1, "Dot product of two tensors"); - - m.def("matmul", [](const PyTensor& t, const PyTensor& t2) { - return PT(Tensor::Matmul(T(t), T(t2))); - }, py::arg("t"), 
py::arg("t2"), "Matrix multiplication of two tensors"); - - m.def("region_begin", [](const std::string& name) { - Tensor::BeginRegion(name); - }, py::arg("name"), "Begin a debug region"); - - m.def("region_end", [](const std::string& name) { - Tensor::EndRegion(name); - }, py::arg("name"), "End a debug region"); - - m.def("register_custom_operation", [](const std::string& name, vector overloads, py::function impl, py::function vjp) { - auto cpp_impl = [impl](Tensors& output, map inputs, const Tensor* tensor, vector axes) { - py::list input_list; - for (auto& [id, tensor] : inputs) { - input_list.append(PT(*tensor)); - } - py::list output_list = impl(input_list, PT(*tensor), axes).cast(); - for (int i = 0; i < output_list.size(); i++) { - PyTensor* t = output_list[i].cast(); - output.push_back(&t->Get()); - } - }; - - auto cpp_vjp = [vjp](map inputs, const Tensor* gradient, const Tensor* tensor) { - py::list input_list; - for (auto& [id, tensor] : inputs) { - input_list.append(PT(*tensor)); - } - py::list output_list = vjp(input_list, PT(*gradient), PT(*tensor)).cast(); - Tensors gradients; - for (int i = 0; i < output_list.size(); i++) { - PyTensor* t = output_list[i].cast(); - gradients.push_back(&t->Get()); - } - return gradients; - }; - - RegisterAlgorithmicPrimitive(name, overloads, cpp_impl, cpp_vjp); - }, py::arg("name"), py::arg("overloads"), py::arg("impl"), py::arg("vjp"), "Register a custom operation"); - - m.def("custom", [](const std::string& name, py::list inputs, py::list shape) { - Tensors input_tensors = TensorsFromList(inputs); - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - return PT(Tensor::CustomOperation(name, input_tensors, shape_tensors)); - }, py::arg("name"), py::arg("inputs"), py::arg("shape"), "Run custom operation"); - - m.def("custom", [](const std::string& name, py::list inputs) { - Tensors input_tensors = TensorsFromList(inputs); - Tensors shape_tensors = input_tensors[0]->GetShape(); - return 
PT(Tensor::CustomOperation(name, input_tensors, shape_tensors)); - }, py::arg("name"), py::arg("inputs"), "Run custom operation"); - - m.def("print_value", [](const std::string& name, const PyTensor& t) { - Tensor::PrintValue(name, T(t)); - }, py::arg("name"), py::arg("t"), "Print the value of the tensor"); - - m.def("assert_value", [](const std::string& name, const PyTensor& t) { - Tensor::AssertValue(name, T(t)); - }, py::arg("name"), py::arg("t"), "Assert the value of the tensor"); -} - -} // namespace TensorFrost diff --git a/TensorFrost/Frontend/Python/Definitions/TensorMemory.cpp b/TensorFrost/Frontend/Python/Definitions/TensorMemory.cpp deleted file mode 100644 index d3b96169..00000000 --- a/TensorFrost/Frontend/Python/Definitions/TensorMemory.cpp +++ /dev/null @@ -1,72 +0,0 @@ -#include -#include - -#include -#include - -namespace TensorFrost { - -void TensorMemoryDefinition(py::module& m, - py::class_& py_tensor_mem) { - //define constructors from numpy arrays - py_tensor_mem.def(py::init([](py::array arr) { - return PyTensorMemory(arr); - }), "Create a TensorMemory from a numpy array", py::return_value_policy::take_ownership); - - // "constructor" - m.def( - "tensor", - [](const std::vector& shape, TFDataFormat type) { - return PyTensorMemory(shape, type); - },"Create a TensorMemory with the given shape", py::return_value_policy::take_ownership); - - // "constructor" from numpy array - m.def( - "tensor", - [](py::array arr) { - return new PyTensorMemory(arr); - }, - "Create a TensorMemory from a numpy array", py::return_value_policy::take_ownership); - - // properties - py_tensor_mem.def_property_readonly("shape", [](const PyTensorMemory& t) { - vector shape = t.Shape(); - return py::cast(shape); - }); - - py_tensor_mem.def_property_readonly("type", [](const PyTensorMemory& t) { - return t.GetFormat(); - }); - - py_tensor_mem.def_property_readonly("size", [](const PyTensorMemory& t) { - return GetSize(t.tensor_); - }); - - // to numpy array - 
py_tensor_mem.def_property_readonly( - "numpy", - [](const PyTensorMemory& t) - -> std::variant, py::array_t, - py::array_t, py::array_t> { - if (t.GetFormat() == TFTypeFloat32) { - return t.ToPyArray(); - } else if (t.GetFormat() == TFTypeInt32) { - return t.ToPyArray(); - } else if (t.GetFormat() == TFTypeUint32) { - return t.ToPyArray(); - } else if (t.GetFormat() == TFTypeBool32) { - return t.ToPyArray(); - } else { - throw std::runtime_error("Unsupported data type for numpy conversion"); - } - }, - "Readback data from tensor memory to a numpy array", py::return_value_policy::take_ownership); - - m.def("allocated_memory", []() { return global_memory_manager->GetAllocatedSize(); }, - "Get the amount of memory currently used by the memory manager"); - - m.def("unused_allocated_memory", []() { return global_memory_manager->GetUnusedAllocatedSize(); }, - "Get the amount of memory currently allocated but not used by the memory manager"); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/TensorProgram.cpp b/TensorFrost/Frontend/Python/Definitions/TensorProgram.cpp deleted file mode 100644 index f84cdc21..00000000 --- a/TensorFrost/Frontend/Python/Definitions/TensorProgram.cpp +++ /dev/null @@ -1,199 +0,0 @@ -#include -#include - -#include -#include -#include - -namespace TensorFrost { - -void TensorProgramDefinition(py::module& m, - py::class_& tensor_program) { - m.def( - "compile", - [](const py::function& py_evaluate) { - // Extract the name of the Python function - std::string func_name = - py_evaluate.attr("__name__").cast(); - - vector inputs = GetFunctionArguments(py_evaluate); - vector arg_names; - vector arg_props; - for (auto arg : inputs) { - arg_names.push_back(std::get<0>(arg)); - py::object arg_prop = std::get<1>(arg); - py::object arg_default = std::get<2>(arg); - if (py::isinstance(arg_prop)) { - PyTensorArg arg_tensor = arg_prop.cast(); - arg_props.push_back(arg_tensor); - } else { - 
throw std::runtime_error("Unsupported input type " + std::string(py::str(arg_prop))); - } - } - - TensorProgram& program = *new TensorProgram( - [py_evaluate, arg_names, arg_props]() -> Tensors { - py::gil_scoped_acquire acquire; - std::vector args; - //create inputs from the arguments - for (size_t i = 0; i < arg_names.size(); i++) { - Tensor& input = Tensor::Input(arg_props[i].shape, arg_props[i].type); - input.SetDebugName(arg_names[i]); - PyTensor* py_tensor = new PyTensor(&input); - args.push_back(py_tensor); - } - //convert to py::args - py::args py_args = py::cast(args); - py::object result = py_evaluate(*py_args); - Tensors outputs; - //if the result is a single tensor - if (py::isinstance(result)) { - outputs.push_back(&py::cast(result).Get()); - } else { - auto py_outputs = py::cast>(result); - for (PyTensor output : py_outputs) { - outputs.push_back(&output.Get()); - } - } - return outputs; - }, - func_name); - - py::print(program.PrintProperties()); - return &program; - }, - "Compile a TensorProgram from a python function"); - - tensor_program.def( - "__call__", - [](TensorProgram& program, py::args py_inputs) -> std::variant { - vector inputs_props; - vector temp_numpy_tensors; - for (auto arg : py_inputs) { - if (py::isinstance(arg)) { //if just tensor memory - PyTensorMemory* mem = &arg.cast(); - inputs_props.push_back(arg.cast()); - } else if (py::isinstance(arg)) { //if module then add its parameters - Module* module = &arg.cast(); - py::list params = module->parameters(); - for (auto param : params) { - inputs_props.push_back(param.cast()); - } - } else if (py::isinstance(arg)) { //if numpy array then create pytensormemory from it and add it - py::array arr = arg.cast(); - PyTensorMemory* temp_tensor = new PyTensorMemory(arr); - inputs_props.push_back(py::cast(temp_tensor, py::return_value_policy::take_ownership)); - temp_numpy_tensors.push_back(temp_tensor->tensor_); - } else if (py::isinstance(arg)) { //if list then convert to py::array then 
create pytensormemory from it and add it - py::array arr = ListToArray(arg.cast()); - PyTensorMemory* temp_tensor = new PyTensorMemory(arr); - inputs_props.push_back(py::cast(temp_tensor, py::return_value_policy::take_ownership)); - temp_numpy_tensors.push_back(temp_tensor->tensor_); - } else { - throw std::runtime_error("Unsupported input type " + std::string(py::str(arg))); - } - } - - vector inputs; - for (auto input : inputs_props) { - PyTensorMemory* mem = input.cast(); - inputs.push_back(mem->tensor_); - } - vector outputs = program.Evaluate(inputs); - - //remove temporary tensors if they are not in the outputs - for (TFTensor* temp_tensor : temp_numpy_tensors) { - bool found = false; - for (TFTensor* output : outputs) { - if (temp_tensor->buffer == output->buffer) { - found = true; - break; - } - } - if (!found) { - global_memory_manager->DeallocateTensor(*temp_tensor); - } - } - - vector output_tensors; - for (size_t i = 0; i < outputs.size(); i++) { - //if any of the outputs are also inputs, then replace them with the input tensors - TFTensor* out = outputs[i]; - bool is_input = false; - for (size_t j = 0; j < inputs_props.size(); j++) { - PyTensorMemory* in = inputs_props[j].cast(); - if (out->buffer == in->tensor_->buffer) { - output_tensors.push_back(inputs_props[j]); - is_input = true; - break; - } - } - if (is_input) { - continue; - } - //otherwise create a new tensor memory - output_tensors.push_back(py::cast(new PyTensorMemory(outputs[i]), py::return_value_policy::take_ownership)); - } - - //if there is only one output, return the tensor memory - if (outputs.size() == 1) { - return output_tensors[0]; - } else { - //convert to py::tuple of PyTensorMemory* - py::tuple py_outputs = py::tuple(outputs.size()); - for (size_t i = 0; i < outputs.size(); i++) { - py_outputs[i] = output_tensors[i]; - } - return py_outputs; - } - }, - "Evaluate the TensorProgram with the given inputs"); - - tensor_program.def( - "list_operations", - [](TensorProgram& program, 
bool compact) { - std::string listing = "List of operations:\n"; - listing += GetOperationListing(program.ir, compact); - return py::str(listing); - }, - py::arg("compact") = true); - - tensor_program.def("compiled_code", [](TensorProgram& program) { - string code = program.program->generated_code_; - return py::str(code); - }); - - tensor_program.def("get_kernels", [](TensorProgram& program) { - vector kernel_source; - for (auto& kernel : program.program->kernels_) { - kernel_source.push_back(kernel.full_generated_code_); - } - return kernel_source; - }); - - tensor_program.def("get_main_function", [](TensorProgram& program) { - return program.program->main_function_; - }); - - tensor_program.def("get_last_execution_time", [](TensorProgram& program) { - return program.program->last_execution_time; - }); - - m.def("get_all_generated_main_functions", []() { - return global_kernel_manager->GetAllMainFunctions(); - }); - - m.def("get_all_generated_kernels", []() { - return global_kernel_manager->GetAllKernels(); - }); - - m.def("get_cpp_header", []() { - return GetCPPHeader(); - }); - - m.def("get_cpp_implementation", []() { - return GetCPPImplementation(); - }); -} - -} // namespace TensorFrost diff --git a/TensorFrost/Frontend/Python/Definitions/TensorScope.cpp b/TensorFrost/Frontend/Python/Definitions/TensorScope.cpp deleted file mode 100644 index 886bb5a4..00000000 --- a/TensorFrost/Frontend/Python/Definitions/TensorScope.cpp +++ /dev/null @@ -1,118 +0,0 @@ -#include -#include - -#include - -namespace TensorFrost { - -void ScopeDefinitions(py::module& m, py::class_& py_tensor) { - m.def( - "loop", - [](const py::function& body, const PyTensor& begin, const PyTensor& end, - const PyTensor& step) { - // wrap the function to convert the PyTensor to Tensor - std::function f2 = [&body](const Tensor& t) { - py::gil_scoped_acquire acquire; - body(PT(t)); - }; - - Tensor::Loop(T(begin), T(end), T(step), f2); - }, - py::arg("begin") = 0, py::arg("end"), py::arg("step") = 
1, - py::arg("body")); - - m.def( - "if_cond", - [](const PyTensor& condition, const py::function& true_body) { - std::function f = [&true_body]() { - py::gil_scoped_acquire acquire; - true_body(); - }; - Tensor::If(T(condition), f); - }, - py::arg("condition"), py::arg("true_body")); - - m.def( - "if_cond", - [](const PyTensor& condition, const py::function& true_body, - const py::function& false_body) { - std::function f1 = [&true_body]() { - py::gil_scoped_acquire acquire; - true_body(); - }; - std::function f2 = [&false_body]() { - py::gil_scoped_acquire acquire; - false_body(); - }; - Tensor::If(T(condition), f1, f2); - }, - py::arg("condition"), py::arg("true_body"), py::arg("false_body")); - - m.def("break_loop", []() { Tensor::Break(); }); - m.def("continue_loop", []() { Tensor::Continue(); }); - - m.def( - "kernel", - [](py::list shape, const py::function& body) { - // wrap the function to convert the PyTensor to Tensor - std::function f2 = - [&body](const Tensors& tensors) { - py::gil_scoped_acquire acquire; - PyTensors py_tensors = PyTensorsFromTensors(tensors); - body(py_tensors); - }; - - Tensor::Kernel(Reverse(TensorsFromList(shape)), f2); - }, - py::arg("shape"), py::arg("body")); - - // m.def( - // "vmap", - // [](py::list inputs, py::list shape, const py::function& func) { - // std::function f = [&func]() { - // py::gil_scoped_acquire acquire; - // func(); - // }; - // Tensor::If(T(condition), f); - // }, - // py::arg("condition"), py::arg("true_body")); - - py_tensor.def("__enter__", &PyTensor::__enter__); - py_tensor.def("__exit__", &PyTensor::__exit__); - - //loop scope - m.def("loop", - [](const PyTensor& begin, const PyTensor& end, const PyTensor& step) { - Tensor& for_loop = Tensor::Loop(T(begin), T(end), T(step)); - return PT(for_loop); - }); - - m.def("loop", - [](const PyTensor& begin, const PyTensor& end) { - Tensor& for_loop = Tensor::Loop(T(begin), T(end), T(PyTensor(1))); - return PT(for_loop); - }); - - m.def("loop", - [](const 
PyTensor& end) { - Tensor& for_loop = Tensor::Loop(T(PyTensor(0)), T(end), T(PyTensor(1))); - return PT(for_loop); - }); - - //if scope - m.def("if_cond", - [](const PyTensor& condition) { - Tensor& if_cond = Tensor::If(T(condition)); - return PT(if_cond); - }); - - //kernel scope - m.def("kernel", - [](py::list shape, vector group_size) { - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - Tensor& kernel = Tensor::Kernel(shape_tensors, group_size); - return PT(kernel); - }, py::arg("shape"), py::arg("group_size") = vector()); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/WindowUtils.cpp b/TensorFrost/Frontend/Python/Definitions/WindowUtils.cpp deleted file mode 100644 index bc7b3ad6..00000000 --- a/TensorFrost/Frontend/Python/Definitions/WindowUtils.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include -#include - -#include -#include - -#include "Backend/RenderDoc.h" - -namespace TensorFrost { - -void WindowDefinitions(py::module& m) { - py::module window = m.def_submodule("window", "Window functions"); - - window.def( - "show", - [](int width, int height, string title) { - ShowWindow(width, height, title.c_str()); - }, - "Show the memory manager window"); - - window.def( - "hide", []() { HideWindow(); }, "Hide the memory manager window"); - - window.def( - "render_frame", [](const PyTensorMemory& t) { RenderFrame(t.tensor_); }, - "Render a frame from the tensor memory"); - - window.def("render_frame", []() { RenderFrame(nullptr); }, - "Render an empty frame"); - - window.def( - "should_close", []() { return WindowShouldClose(); }, - "Check if the window should close"); - - window.def( - "get_mouse_position", []() { return GetMousePosition(); }, - "Get the current mouse position"); - - window.def( - "get_size", []() { return GetWindowSize(); }, - "Get the current window size"); - - window.def( - "is_mouse_button_pressed", - [](int button) { return IsMouseButtonPressed(button); }, - "Check if a 
mouse button is pressed"); - - window.def( - "is_key_pressed", [](int key) { return IsKeyPressed(key); }, - "Check if a key is pressed"); - - window.attr("MOUSE_BUTTON_0") = GLFW_MOUSE_BUTTON_1; - window.attr("MOUSE_BUTTON_1") = GLFW_MOUSE_BUTTON_2; - window.attr("MOUSE_BUTTON_2") = GLFW_MOUSE_BUTTON_3; - - window.attr("KEY_SPACE") = GLFW_KEY_SPACE; - window.attr("KEY_APOSTROPHE") = GLFW_KEY_APOSTROPHE; - window.attr("KEY_COMMA") = GLFW_KEY_COMMA; - window.attr("KEY_MINUS") = GLFW_KEY_MINUS; - window.attr("KEY_PERIOD") = GLFW_KEY_PERIOD; - - window.attr("KEY_A") = GLFW_KEY_A; - window.attr("KEY_B") = GLFW_KEY_B; - window.attr("KEY_C") = GLFW_KEY_C; - window.attr("KEY_D") = GLFW_KEY_D; - window.attr("KEY_E") = GLFW_KEY_E; - window.attr("KEY_F") = GLFW_KEY_F; - window.attr("KEY_G") = GLFW_KEY_G; - window.attr("KEY_H") = GLFW_KEY_H; - window.attr("KEY_I") = GLFW_KEY_I; - window.attr("KEY_J") = GLFW_KEY_J; - window.attr("KEY_K") = GLFW_KEY_K; - window.attr("KEY_L") = GLFW_KEY_L; - window.attr("KEY_M") = GLFW_KEY_M; - window.attr("KEY_N") = GLFW_KEY_N; - window.attr("KEY_O") = GLFW_KEY_O; - window.attr("KEY_P") = GLFW_KEY_P; - window.attr("KEY_Q") = GLFW_KEY_Q; - window.attr("KEY_R") = GLFW_KEY_R; - window.attr("KEY_S") = GLFW_KEY_S; - window.attr("KEY_T") = GLFW_KEY_T; - window.attr("KEY_U") = GLFW_KEY_U; - window.attr("KEY_V") = GLFW_KEY_V; - window.attr("KEY_W") = GLFW_KEY_W; - window.attr("KEY_X") = GLFW_KEY_X; - window.attr("KEY_Y") = GLFW_KEY_Y; - window.attr("KEY_Z") = GLFW_KEY_Z; - - py::module imgui = m.def_submodule("imgui", "ImGui functions"); - - imgui.def( - "begin", [](string name) { ImGuiBegin(name); }, - "Begin a new ImGui window"); - - imgui.def("end", []() { ImGuiEnd(); }, "End the current ImGui window"); - - imgui.def( - "text", [](string text) { ImGuiText(text); }, - "Add text to the current ImGui window"); - - imgui.def("slider", - [](string text, int* value, int min, int max) { - ImGuiSlider(text, value, min, max); - return *value; - }, "Add a 
slider to the current ImGui window"); - - imgui.def("slider", [](string text, float* value, float min, float max) { - ImGuiSlider(text, value, min, max); - return *value; - }, "Add a slider to the current ImGui window"); - - imgui.def("button", [](string text) { return ImGuiButton(text); }, - "Add a button to the current ImGui window"); - - imgui.def("checkbox", [](string text, bool* value) { - ImGuiCheckbox(text, value); - return *value; - }, "Add a checkbox to the current ImGui window"); - - imgui.def("plotlines", [](string label, py::array_t values, int values_offset, string overlay_text, float scale_min, float scale_max, py::tuple graph_size, int stride) { - ImGuiPlotLines(label.c_str(), values.data(), (int)values.size(), values_offset, overlay_text.c_str(), scale_min, scale_max, ImVec2(graph_size[0].cast(), graph_size[1].cast()), stride); - }, py::arg("label"), py::arg("values"), py::arg("values_offset") = 0, py::arg("overlay_text") = "", py::arg("scale_min") = FLT_MAX, py::arg("scale_max") = FLT_MAX, py::arg("graph_size") = py::make_tuple(0.0f, 0.0f), py::arg("stride") = sizeof(float)); - - imgui.def("scale_all_sizes", [](float scale) { ImGuiScaleAllSizes(scale); }, - "Scale all ImGui sizes by a factor"); - - imgui.def("add_background_text", [](string text, py::tuple pos, py::tuple color) { - ImGuiAddBackgroundText(text, ImVec2(pos[0].cast(), pos[1].cast()), ImVec4(color[0].cast(), color[1].cast(), color[2].cast(), color[3].cast())); - }, py::arg("text"), py::arg("pos"), py::arg("color")); - - imgui.def("color_picker3", [](string text, py::array_t color) { - ImGuiColorPicker3(text, color.mutable_data()); - return color; - }, py::arg("text"), py::arg("color")); - - imgui.def("color_picker4", [](string text, py::array_t color) { - ImGuiColorPicker4(text, color.mutable_data()); - return color; - }, py::arg("text"), py::arg("color")); - - //renderdoc - m.def("renderdoc_start_capture", []() { StartRenderDocCapture(); }, - "Start a RenderDoc capture"); - 
m.def("renderdoc_end_capture", []() { EndRenderDocCapture(); }, - "End a RenderDoc capture"); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PyModule.h b/TensorFrost/Frontend/Python/PyModule.h deleted file mode 100644 index 87d2c7d4..00000000 --- a/TensorFrost/Frontend/Python/PyModule.h +++ /dev/null @@ -1,430 +0,0 @@ -#include -#include -#include -#include -#include -#include - -namespace TensorFrost { - -namespace py = pybind11; - -class Parameter { -public: - std::vector shape; - TFDataFormat dtype; - float random_scale; - float random_offset; - string debug_name; - bool optimize; - - Parameter(const std::vector& shape, TFDataFormat dtype, float random_scale = -1.0f, float random_offset = 0.0f, bool requires_grad = true) - : shape(shape), dtype(dtype), random_scale(random_scale), random_offset(random_offset), optimize(requires_grad), debug_name("") {} - - bool CanBeInitialized() { - for (int i = 0; i < shape.size(); i++) { - if (shape[i] == -1) { - return false; - } - } - return true; - } -}; - -class ParameterArray { -public: - map _parameters; //must be sorted - map _requires_grad; - - py::object getitem(size_t index) { - if (_parameters.contains(index)) { - return _parameters[index]; - } - throw py::index_error("Index " + std::to_string(index) + " is not in the ParameterArray"); - } - - void setitem(size_t index, py::object value) { - _parameters[index] = value; - if (py::isinstance(value)) { - _requires_grad[index] = py::cast(value).optimize; - } - } - - vector> items() { - vector> params; - for (auto& param : _parameters) { - params.push_back({param.first, param.second}); - } - return params; - } -}; - -class Module { -public: - enum class AttributeType { - None, - Parameter, - ParameterArray, - Module - }; - - map _attributes; - map _attribute_types; - map _optimize; - vector _attribute_order; - bool optimize = true; - - py::object tf; - - Module(bool requires_grad = true) : optimize(requires_grad) { - 
tf = py::module::import("TensorFrost"); - } - - py::object getattr(const std::string& name) { - if (_attributes.contains(name)) { - return _attributes[name]; - } - throw py::attribute_error("TensorFrost Module object has no attribute with name '" + name + "'"); - } - - void setattr(const std::string& name, py::object value) { - AttributeType type = AttributeType::None; - bool optimize = true; - if (py::isinstance(value)) { - type = AttributeType::Parameter; - optimize = py::cast(value).optimize; - //if not float then can't optimize - if(py::cast(value).dtype.type != TFType::Float) { - optimize = false; - } - } else if (py::isinstance(value)) { - type = AttributeType::ParameterArray; - } else if (py::isinstance(value)) { - type = AttributeType::Module; - optimize = py::cast(value).optimize; - } else if (py::isinstance(value)) { - PyTensor& tensor = py::cast(value); - tensor.Get().node_->debug_name = name; - } - - bool already_exists = _attributes.contains(name); - if(type == AttributeType::None && already_exists) { - type = _attribute_types[name]; - optimize = _optimize[name]; - } - - _attributes[name] = value; - _attribute_types[name] = type; - _optimize[name] = optimize && this->optimize; - if (!already_exists) _attribute_order.push_back(name); - } - - bool hasattr(const std::string& name) { - return _attributes.contains(name); - } - - bool param_requires_grad(const std::string& name) { - if (_optimize.contains(name)) { - return _optimize[name]; - } - return optimize; - } - - vector> get_attributes_of_type(AttributeType type) { - vector> params; - for (auto& attr : _attribute_order) { - if (_attribute_types[attr] == type) { - params.push_back({attr, _attributes[attr]}); - } - } - return params; - } - - virtual void assert_parameters() {} - - py::object initialize_parameter_native(Parameter& param) { - //Convert to list - vector shape_list; - for (int i = 0; i < param.shape.size(); i++) { - shape_list.push_back(tf.attr("const")(param.shape[i])); - } - py::object 
result; - if(param.dtype.type == TFType::Float) { - float shape_sum = 0.0f; - for (int i = 0; i < param.shape.size(); i++) { - shape_sum += (float)param.shape[i]; - } - float scale = sqrt(6.0f / shape_sum); - if(param.random_scale >= 0.0f) { - scale = param.random_scale; - } - result = tf.attr("random_value")(shape_list, tf.attr("const")(0u)); - result = result.attr("__mul__")(py::float_(2.0f)); - result = result.attr("__sub__")(py::float_(1.0f)); - result = result.attr("__mul__")(py::float_(scale)); - result = result.attr("__add__")(py::float_(param.random_offset)); - } else if (param.dtype.type == TFType::Int) { //just use zeros - result = tf.attr("const")(0, shape_list); - } else if (param.dtype.type == TFType::Uint) { //just use zeros - result = tf.attr("const")(0, shape_list); - } else { //just use zeros - result = tf.attr("const")(false, shape_list); - } - if(param.debug_name != "") { - result = result.attr("set_debug_name")(param.debug_name); - } - return result; - } - - void initialize_parameters_native() { - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - module.second.attr("initialize_parameters_native")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - if(!py::isinstance(param.second)) { - continue; - } - - Parameter& p = py::cast(param.second); - if (!p.CanBeInitialized()) { - //replace all -1 shapes with 1 - for (int i = 0; i < p.shape.size(); i++) { - if (p.shape[i] == -1) { - p.shape[i] = 1; - } - } - } - py::object tensor = initialize_parameter_native(p); - setattr(param.first, tensor); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - if(!py::isinstance(param.second)) { - continue; - } - Parameter& p = py::cast(param.second); - if (!p.CanBeInitialized()) { - continue; - } - py::object tensor = initialize_parameter_native(p); - 
param_array.setitem(param.first, tensor); - } - } - } - - void initialize_input() { - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - module.second.attr("initialize_input")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - Parameter& p = py::cast(param.second); - py::object tensor = tf.attr("input")(p.shape, p.dtype); - if(p.debug_name != "") { - tensor = tensor.attr("set_debug_name")(p.debug_name); - } - setattr(param.first, tensor); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - Parameter& p = py::cast(param.second); - py::object tensor = tf.attr("input")(p.shape, p.dtype); - if(p.debug_name != "") { - tensor = tensor.attr("set_debug_name")(p.debug_name); - } - param_array.setitem(param.first, tensor); - } - } - - assert_parameters(); - } - - py::object initialize_parameter(Parameter& param) { - py::object np = py::module::import("numpy"); - py::object random = np.attr("random"); - - // Convert the shape vector to a tuple - py::tuple shape_tuple = py::cast(param.shape); - - if(param.dtype.type == TFType::Float) { - // Generate uniform random values instead of normal - py::array_t arr = random.attr("uniform")(-1.0f, 1.0f, shape_tuple).cast>(); - float shape_sum = 0.0f; - for (int i = 0; i < param.shape.size(); i++) { - shape_sum += (float)param.shape[i]; - } - float scale = sqrt(6.0f / shape_sum); - if(param.random_scale >= 0.0f) { - scale = param.random_scale; - } - - arr = arr.attr("__mul__")(py::float_(scale)); - arr = arr.attr("__add__")(py::float_(param.random_offset)); - return tf.attr("tensor")(arr); - } else if (param.dtype.type == TFType::Int) { //just use zeros - py::array_t arr = np.attr("zeros")(shape_tuple).cast>(); - return tf.attr("tensor")(arr); - } else if (param.dtype.type == TFType::Uint) { //just use zeros - py::array_t arr = 
np.attr("zeros")(shape_tuple).cast>(); - return tf.attr("tensor")(arr); - } else { //just use zeros - py::array_t arr = np.attr("zeros")(shape_tuple).cast>(); - return tf.attr("tensor")(arr); - } - } - - void initialize_parameters() { - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - module.second.attr("initialize_parameters")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - if(!py::isinstance(param.second)) { - continue; - } - - Parameter& p = py::cast(param.second); - if (!p.CanBeInitialized()) { - continue; - } - py::object tensor = initialize_parameter(p); - setattr(param.first, tensor); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - if(!py::isinstance(param.second)) { - continue; - } - Parameter& p = py::cast(param.second); - if (!p.CanBeInitialized()) { - continue; - } - py::object tensor = initialize_parameter(p); - param_array.setitem(param.first, tensor); - } - } - } - - py::list parameters() { - py::list params; - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - params += module.second.attr("parameters")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - params.append(param.second); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - params.append(param.second); - } - } - return params; - } - - py::list named_parameters() { - py::list params; - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - params += module.second.attr("named_parameters")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - params.append(py::make_tuple(param.first, param.second)); - } - - for (auto& array : 
get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - params.append(py::make_tuple(array.first + "[" + std::to_string(param.first) + "]", param.second)); - } - } - return params; - } - - py::list requires_grads_list() { - py::list requires_grads; - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - requires_grads.append( param_requires_grad(module.first) ); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - requires_grads.append( param_requires_grad(param.first) ); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - bool requires_grad = param_requires_grad(array.first); - for (auto& param : param_array._parameters) { - requires_grads.append( param_array._requires_grad[param.first] && requires_grad ); - } - } - return requires_grads; - } - - py::list create_input(py::args args) { - py::list inputs = parameters(); - inputs += args; - return inputs; - } - - void update_parameters(py::object parameter_values) { - py::list params; - if (py::isinstance(parameter_values)) { - params = parameter_values; - } else if (py::isinstance(parameter_values)) { - params = py::list(parameter_values); - } else { - throw py::type_error("parameter_values must be a list or tuple"); - } - - int index = 0; - - std::function update_params; - - update_params = [&](Module& module) { - for (auto& module_item : module.get_attributes_of_type(AttributeType::Module)) { - update_params(py::cast(module_item.second)); - } - - for (auto& param : module.get_attributes_of_type(AttributeType::Parameter)) { - if (index >= py::len(params)) { - throw py::index_error("Provided more than " + std::to_string(index) + " values, but expected " + std::to_string(py::len(params))); - } - module.setattr(param.first, params[index]); - index++; - } - - for (auto& array : 
module.get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - if (index >= py::len(params)) { - throw py::index_error("Provided more than " + std::to_string(index) + " values, but expected " + std::to_string(py::len(params))); - } - param_array.setitem(param.first, params[index]); - index++; - } - } - }; - - update_params(*this); - } - - virtual py::object loss(py::object X, py::object Y) { - throw std::runtime_error("Not implemented"); - } - - virtual py::object forward(py::object X) { - throw std::runtime_error("Not implemented"); - } -}; - -} \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PyTensor.cpp b/TensorFrost/Frontend/Python/PyTensor.cpp deleted file mode 100644 index a1ec2047..00000000 --- a/TensorFrost/Frontend/Python/PyTensor.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "Frontend/Python/PyTensor.h" - -#include "PyTensorMemory.h" - -namespace TensorFrost { - -PyTensors PyTensorsFromTuple(const py::tuple& tuple) { - PyTensors tensors; - for (auto arg : tuple) { - tensors.push_back(&arg.cast()); - } - return tensors; -} - -Tensors TensorsFromTuple(const py::tuple& tuple) { - Tensors tensors; - for (auto arg : tuple) { - tensors.push_back(&arg.cast().Get()); - } - return tensors; -} - -tuple SliceToTensors(const py::slice& slice) { - PyObject* pyslice = slice.ptr(); - PySliceObject* slice_obj = (PySliceObject*)pyslice; - PyObject* start = slice_obj->start; - PyObject* stop = slice_obj->stop; - PyObject* step = slice_obj->step; - - py::object start_obj = py::reinterpret_borrow(start); - py::object stop_obj = py::reinterpret_borrow(stop); - py::object step_obj = py::reinterpret_borrow(step); - - PyTensor* start_tensor = nullptr; - PyTensor* stop_tensor = nullptr; - PyTensor* step_tensor = nullptr; - - start_tensor = &start_obj.cast(); - stop_tensor = &stop_obj.cast(); - step_tensor = &step_obj.cast(); - - return {start_tensor, 
stop_tensor, step_tensor}; -} - -//Tensors TensorsFromTensorIndices(const Tensor* t, const py::tuple& tuple) { -// Tensors tensors; -// for (auto arg : tuple) { -// //if index is a tensor -// if (py::isinstance(arg)) { -// tensors.push_back(&arg.cast().Get()); -// } // if index is a slice -// else if (py::isinstance(arg)) { -// auto slice = arg.cast(); -// //get native python slice -// PyObject* pyslice = slice.ptr(); -// PySliceObject* slice_obj = (PySliceObject*)pyslice; -// //get start, stop, and step -// PyObject* start = slice_obj->start; -// PyObject* stop = slice_obj->stop; -// PyObject* step = slice_obj->step; -// -// //convert start, stop, and step to py::object -// py::object start_obj = py::reinterpret_borrow(start); -// py::object stop_obj = py::reinterpret_borrow(stop); -// py::object step_obj = py::reinterpret_borrow(step); -// -// //try to cast to PyTensor -// PyTensor* start_tensor = nullptr; -// PyTensor* stop_tensor = nullptr; -// PyTensor* step_tensor = nullptr; -// -// try { -// start_tensor = &start_obj.cast(); -// } catch (const py::cast_error& e) { -// //do nothing -// } -// -// try { -// stop_tensor = &stop_obj.cast(); -// } catch (const py::cast_error& e) { -// //do nothing -// } -// -// try { -// step_tensor = &step_obj.cast(); -// } catch (const py::cast_error& e) { -// //do nothing -// } -// -// //if start, stop, and step are all PyTensor -// -// -// -// -// -// else { -// throw std::invalid_argument("Invalid index type"); -// } -// } -// return tensors; -//} - -PyTensors PyTensorsFromList(const py::list& list) { - PyTensors tensors; - for (auto arg : list) { - tensors.push_back(&arg.cast()); - } - return tensors; -} - -Tensors TensorsFromList(const py::list& list) { - Tensors tensors; - for (auto arg : list) { - tensors.push_back(&arg.cast().Get()); - } - return tensors; -} - -PyTensors PyTensorsFromTensors(const Tensors& tensors) { - PyTensors py_tensors; - for (const auto* tensor : tensors) { - py_tensors.push_back(new 
PyTensor(tensor)); - } - return py_tensors; -} - -std::variant PyTensorsToTupleVariant(const PyTensors &tensors) { - if (tensors.size() == 1) { - //if there is only one tensor, return the tensor - return tensors[0]; - } else { - //convert to py::tuple of PyTensor* - return py::tuple(py::cast(tensors)); - } -} - -void UpdateTensorNames() { - PyObject* p = PyEval_GetLocals(); - py::dict all_names = py::reinterpret_borrow(p ? p : py::module_::import("__main__").attr("__dict__").ptr()); - - for (auto item : all_names) { - std::string var_name = py::str(item.first); - py::object var_value = py::reinterpret_borrow(item.second); - if (py::isinstance(var_value)) { - PyTensor& py_tensor = var_value.cast(); - const Tensor* tensor = &py_tensor.Get(); - tensor->SetDebugName(var_name); - } - } -} - -std::vector GetFunctionArguments(const py::function& func) { - PyObject* fn = func.ptr(); - PyObject* code_obj = PyFunction_GetCode(fn); - if (!code_obj) { - throw std::runtime_error("Could not retrieve code object"); - } - - PyObject* varnames = PyObject_GetAttrString(code_obj, "co_varnames"); - if (!varnames || !PyTuple_Check(varnames)) { - Py_XDECREF(varnames); - throw std::runtime_error("Could not retrieve varnames or varnames is not a tuple"); - } - - PyObject* argcount_obj = PyObject_GetAttrString(code_obj, "co_argcount"); - if (!argcount_obj) { - Py_XDECREF(varnames); - throw std::runtime_error("Could not retrieve argument count"); - } - int arg_count = PyLong_AsLong(argcount_obj); - Py_XDECREF(argcount_obj); - if (PyErr_Occurred()) { - Py_XDECREF(varnames); - throw std::runtime_error("Could not retrieve argument count"); - } - - PyObject* annotations = PyObject_GetAttrString(fn, "__annotations__"); - PyObject* defaults = PyObject_GetAttrString(fn, "__defaults__"); - - std::vector arg_info_list; - for (int i = 0; i < arg_count; ++i) { - PyObject* name = PyTuple_GetItem(varnames, i); - if (!name || !PyUnicode_Check(name)) { - Py_XDECREF(varnames); - Py_XDECREF(annotations); - 
Py_XDECREF(defaults); - throw std::runtime_error("Argument name is not a valid Unicode string"); - } - std::string arg_name = PyUnicode_AsUTF8(name); - - // Get annotation - PyObject* annotation = annotations ? PyDict_GetItemString(annotations, arg_name.c_str()) : nullptr; - - // Get default value - PyObject* default_val = (defaults && PyTuple_Check(defaults) && i >= (arg_count - PyTuple_Size(defaults))) - ? PyTuple_GetItem(defaults, i - (arg_count - PyTuple_Size(defaults))) - : nullptr; - - py::object annotation_obj = py::reinterpret_borrow(annotation); - py::object default_obj = py::reinterpret_borrow(default_val); - - arg_info_list.emplace_back(arg_name, annotation_obj, default_obj); - } - - Py_XDECREF(varnames); - Py_XDECREF(annotations); - Py_XDECREF(defaults); - - return arg_info_list; -} - -py::array ListToArray(py::list input_list) { - // Get the numpy module - py::module np = py::module::import("numpy"); - - // Convert the list to a numpy array - py::array np_array = np.attr("array")(input_list); - - return np_array; -} - -std::string r_op(const std::string& name) { return "__r" + name + "__"; } - -std::string l_op(const std::string& name) { return "__" + name + "__"; } - -} // namespace TensorFrost diff --git a/TensorFrost/Frontend/Python/PyTensor.h b/TensorFrost/Frontend/Python/PyTensor.h deleted file mode 100644 index 84a50c4f..00000000 --- a/TensorFrost/Frontend/Python/PyTensor.h +++ /dev/null @@ -1,104 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace TensorFrost { - -#define PT(tensor) PyTensor(&(tensor)) -#define T(tensor) (tensor).Get() - -namespace py = pybind11; - -void UpdateTensorNames(); -class PyTensor; - -using PyTensors = std::vector; - -PyTensors PyTensorsFromTuple(const py::tuple& tuple); -Tensors TensorsFromTuple(const py::tuple& tuple); -PyTensors PyTensorsFromList(const py::list& list); -Tensors TensorsFromList(const py::list& list); -PyTensors PyTensorsFromTensors(const Tensors& tensors); 
-std::variant PyTensorsToTupleVariant(const PyTensors& tensors); - -using ArgInfo = std::tuple; // (name, annotation, default) - -vector GetFunctionArguments(const py::function& func); - -py::array ListToArray(py::list input_list); - -// Tensor wrapper for python -class PyTensor { - Tensor* tensor_; - Tensors indices; - Tensor* value = nullptr; - - public: - explicit PyTensor(Tensor* tensor) : tensor_(tensor) {} - explicit PyTensor(const Tensor* tensor) : tensor_(const_cast(tensor)) {} - ~PyTensor() { UpdateTensorNames(); } - - //tensor view constructor - explicit PyTensor(const Tensor* value, Tensors& indices) - : value(const_cast(value)), indices(std::move(indices)) { - tensor_ = &Tensor::Load(*value, this->indices); - } - - const Tensor& Get() const { return *tensor_; } - - Tensor* Value() const { - if (value == nullptr) { - throw std::runtime_error("Not a tensor view"); - } - return value; - } - - Tensors Indices() const { - if (value == nullptr) { - throw std::runtime_error("Not a tensor view"); - } - return indices; - } - - explicit PyTensor(float value) { tensor_ = &Tensor::Constant(value); } - explicit PyTensor(int value) { tensor_ = &Tensor::Constant(value); } - explicit PyTensor(unsigned int value) { tensor_ = &Tensor::Constant(value); } - explicit PyTensor(bool value) { tensor_ = &Tensor::Constant(value); } - - std::variant __enter__() { - //py::print("Entering node scope"); - std::variant entered = tensor_->Enter(); - if (std::holds_alternative(entered)) { - return new PyTensor(std::get(entered)); - } else { - auto tensors = std::get(entered); - //convert to py::tuple of PyTensor* - return PyTensorsToTupleVariant(PyTensorsFromTensors(Reverse(tensors))); - } - } - - void __exit__(py::object exc_type, py::object exc_value, - py::object traceback) { - //py::print("Exiting node scope"); - tensor_->Exit(); - } -}; - -class PyTensorArg { -public: - std::vector shape; - TFDataFormat type; - - PyTensorArg(std::vector shape, TFDataFormat type) - : 
shape(std::move(shape)), type(type) {} -}; - -std::string r_op(const std::string& name); -std::string l_op(const std::string& name); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PyTensorMemory.cpp b/TensorFrost/Frontend/Python/PyTensorMemory.cpp deleted file mode 100644 index 75b1ee34..00000000 --- a/TensorFrost/Frontend/Python/PyTensorMemory.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "Frontend/Python/PyTensor.h" -#include "Frontend/Python/PyTensorMemory.h" - -namespace TensorFrost { -PyTensorMemory::PyTensorMemory(py::array arr) { - py::buffer_info info = arr.request(); - - // Get the shape - std::vector shape; - for (size_t i = 0; i < (size_t)info.ndim; i++) { - shape.push_back(info.shape[i]); - } - - // Create the data vector - std::vector data; - data.reserve(info.size); - - // Determine the data type and conversion function - std::function convert; - TFDataFormat format = TFTypeNone; - switch (info.format[0]) { - case 'f': // float32 - convert = [](char* ptr) { return *reinterpret_cast(ptr); }; - format = TFTypeFloat32; - break; - case 'i': // int32 - convert = [](char* ptr) { return static_cast(*reinterpret_cast(ptr)); }; - format = TFTypeInt32; - break; - case 'q': // int64 (convert to int32 before casting) - convert = [](char* ptr) { - int32_t val = (int32_t)*reinterpret_cast(ptr); - return *reinterpret_cast(&val); - }; - format = TFTypeInt32; - break; - case 'Q': // uint64 (convert to uint32 before casting) - convert = [](char* ptr) { - uint32_t val = (uint32_t)*reinterpret_cast(ptr); - return val; - }; - format = TFTypeUint32; - break; - case 'L': // uint32 - case 'I': // uint32 - convert = [](char* ptr) { return *reinterpret_cast(ptr); }; - format = TFTypeUint32; - break; - case '?': // bool - convert = [](char* ptr) { return static_cast(*reinterpret_cast(ptr)); }; - format = TFTypeBool32; - break; - case 'd': // float64 (convert to float32 before casting) - convert = [](char* ptr) { float val = 
(float)*reinterpret_cast(ptr); return *reinterpret_cast(&val); }; - format = TFTypeFloat32; - break; - case 'l': // int64 (convert to int32 before casting) - convert = [](char* ptr) { int32_t val = (int32_t)*reinterpret_cast(ptr); return *reinterpret_cast(&val); }; - format = TFTypeInt32; - break; - default: - throw std::runtime_error("Unsupported data type to create TensorMemory from numpy array, format: " + std::string(info.format)); - } - - // Define a recursive lambda function for multi-dimensional iteration - std::function&)> iter_dims; - iter_dims = [&](const size_t dim, std::vector& indices) { - if (dim == info.ndim) { - // Calculate the actual memory address using strides - char* ptr = static_cast(info.ptr); - for (size_t i = 0; i < (size_t)info.ndim; ++i) { - ptr += indices[i] * info.strides[i]; - } - data.push_back(convert(ptr)); - } else { - for (indices[dim] = 0; indices[dim] < (size_t)info.shape[dim]; ++indices[dim]) { - iter_dims(dim + 1, indices); - } - } - }; - - // Start the multidimensional iteration - std::vector start_indices(info.ndim, 0); - iter_dims(0, start_indices); - - // Allocate the memory - tensor_ = global_memory_manager->AllocateTensorWithData(shape, data, format); -} - -vector TensorMemoryFromTuple(const py::tuple& tuple) { - vector memories; - for (auto arg : tuple) { - memories.push_back(&arg.cast()); - } - return memories; -} - -vector TensorMemoryFromList(const py::list& list) { - vector memories; - for (auto arg : list) { - memories.push_back(&arg.cast()); - } - return memories; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PyTensorMemory.h b/TensorFrost/Frontend/Python/PyTensorMemory.h deleted file mode 100644 index e5ca77ed..00000000 --- a/TensorFrost/Frontend/Python/PyTensorMemory.h +++ /dev/null @@ -1,60 +0,0 @@ -#include -#include -#include -#include -#include -#include - -namespace TensorFrost { - -namespace py = pybind11; - -// Tensor wrapper for python -class 
PyTensorMemory { - public: - TFTensor* tensor_; - - explicit PyTensorMemory(TFTensor* tensor) : tensor_(tensor) {} - - PyTensorMemory(vector shape, TFDataFormat type = TFTypeFloat32) { - tensor_ = global_memory_manager->AllocateTensor(shape, type); - } - - TFDataFormat GetFormat() const { - return tensor_->format; - } - - PyTensorMemory(py::array arr); - - template - py::array_t ToPyArray() const { - // Get the shape - std::vector shape = GetShape(tensor_); - - // Create the numpy array - py::array_t arr(shape); - - // Copy the data - std::vector data = global_memory_manager->Readback(tensor_); - T* ptr = static_cast(arr.request().ptr); - for (int i = 0; i < data.size(); i++) { - ptr[i] = *(reinterpret_cast(&data[i])); - } - - return arr; - } - - vector Shape() const { - return TensorFrost::GetShape(tensor_); - } - - ~PyTensorMemory() { - global_memory_manager->DeallocateTensor(*tensor_); - } - -}; - -vector TensorMemoryFromTuple(const py::tuple& tuple); -vector TensorMemoryFromList(const py::list& list); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PybindModule.cpp b/TensorFrost/Frontend/Python/PybindModule.cpp deleted file mode 100644 index c3351736..00000000 --- a/TensorFrost/Frontend/Python/PybindModule.cpp +++ /dev/null @@ -1,128 +0,0 @@ -#include -#include - -#include -#include -#include - -namespace TensorFrost { - -void PyTensorDefinition(py::module&, py::class_&); -void TensorFunctionsDefinition(py::module&); -void TensorProgramDefinition(py::module&, py::class_&); -void TensorMemoryDefinition(py::module& m, - py::class_& py_tensor_mem); -void WindowDefinitions(py::module& m); -void ScopeDefinitions(py::module& m, py::class_& py_tensor); -void ModuleDefinitions(py::module& m); - -PYBIND11_MODULE(TensorFrost, m) { - m.doc() = "TensorFrost library"; - auto data_type = py::enum_(m, "TFType"); - auto backend_type = py::enum_(m, "BackendType"); - auto code_gen_lang = py::enum_(m, "CodeGenLang"); - auto 
py_tensor = py::class_(m, "Tensor"); - auto tensor_program = py::class_(m, "TensorProgram"); - auto py_tensor_mem = py::class_(m, "TensorMemory"); - auto py_tensor_arg = py::class_(m, "Arg"); - - data_type.value("float", TFType::Float); - data_type.value("int", TFType::Int); - data_type.value("uint", TFType::Uint); - data_type.value("bool", TFType::Bool); - - auto data_format = py::class_(m, "TFDataFormat"); - data_format.def(py::init()); - data_format.def_readwrite("type", &TFDataFormat::type); - data_format.def_readwrite("size", &TFDataFormat::size); - // Add printers for the enums - data_format.def("__repr__", [](const TFDataFormat& a) { - return ""; - }); - data_format.def("__str__", [](const TFDataFormat& a) { - return ""; - }); - data_type.def("__repr__", [](TFType a) { - return ""; - }); - data_type.def("__str__", [](TFType a) { - return ""; - }); - - backend_type.value("cpu", BackendType::CPU); - backend_type.value("vulkan", BackendType::Vulkan); - backend_type.value("opengl", BackendType::OpenGL); - backend_type.value("codegen", BackendType::CodeGen); - code_gen_lang.value("cpp", CodeGenLang::CPP); - code_gen_lang.value("glsl", CodeGenLang::GLSL); - code_gen_lang.value("hlsl", CodeGenLang::HLSL); - - data_format.def(py::self == py::self); - backend_type.def("__eq__", [](BackendType a, BackendType b) { - return a == b; - }); - code_gen_lang.def("__eq__", [](CodeGenLang a, CodeGenLang b) { - return a == b; - }); - data_type.def("__eq__", [](TFType a, TFType b) { - return a == b; - }); - - m.attr("float32") = TFTypeFloat32; - m.attr("int32") = TFTypeInt32; - m.attr("uint32") = TFTypeUint32; - m.attr("bool1") = TFTypeBool32; - - m.attr("cpu") = BackendType::CPU; - m.attr("vulkan") = BackendType::Vulkan; - m.attr("opengl") = BackendType::OpenGL; - m.attr("codegen") = BackendType::CodeGen; - - m.attr("cpp_lang") = CodeGenLang::CPP; - m.attr("glsl_lang") = CodeGenLang::GLSL; - m.attr("hlsl_lang") = CodeGenLang::HLSL; - - PyTensorDefinition(m, py_tensor); - - // 
implicit conversion from TensorView to PyTensor - py::implicitly_convertible(); - py::implicitly_convertible(); - py::implicitly_convertible(); - py::implicitly_convertible(); - - TensorFunctionsDefinition(m); - TensorProgramDefinition(m, tensor_program); - TensorMemoryDefinition(m, py_tensor_mem); - WindowDefinitions(m); - ScopeDefinitions(m, py_tensor); - ModuleDefinitions(m); - - py_tensor_arg.def(py::init([](py::list shape, TFDataFormat type) { - std::vector shape_vec; - for (auto& s : shape) { - shape_vec.push_back(s.cast()); - } - return PyTensorArg(shape_vec, type); - }), "Create a TensorArg with the given shape and type", py::return_value_policy::take_ownership); - - m.def("current_backend", []() { - return current_backend; - }, "Get the current backend"); - - m.def("initialize", - [](BackendType backend_type, const std::string& kernel_compile_options, CodeGenLang kernel_lang) { - InitializeBackend(backend_type, kernel_compile_options, kernel_lang); - }, py::arg("backend_type") = BackendType::CPU, py::arg("kernel_compile_options") = "", py::arg("kernel_lang") = CodeGenLang::None, "Initialize the backend"); - - m.def("strip_debug_info", [](bool strip) { - strip_debug_names = strip; - }, py::arg("strip") = true, "Strip debug info from the kernel"); - -#ifdef NDEBUG - py::print("TensorFrost module loaded!"); -#else - py::print("TensorFrost module loaded in debug mode! Expect slow performance."); -#endif -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/IR/include/Overloads.h b/TensorFrost/IR/include/Overloads.h deleted file mode 100644 index 97262d4d..00000000 --- a/TensorFrost/IR/include/Overloads.h +++ /dev/null @@ -1,236 +0,0 @@ -#pragma once -#include "Operation.h" - -namespace TensorFrost { - -Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args, std::vector shape); - -template -Op& func_op(std::string op, const Args&... 
args) { - std::vector mem; - std::vector ids; - std::vector args_vec = {&args...}; - std::vector shape; - return make_op(op, mem, ids, args_vec, shape); -} - -Op& constant(int value); -Op& constant(uint value); -Op& constant(float value); -Op& constant(bool value); - -template -concept Num = std::is_arithmetic_v>; - -template -inline Op& as_op(T v) -{ - using D = std::remove_cvref_t; - using Target = - std::conditional_t, bool, - std::conditional_t, float, - std::conditional_t, unsigned int, - int>>>; - return constant(static_cast(v)); -} - -#define UNARY_OPERATOR(op, name) \ -template \ -Op& operator op(const T& a) { \ - return func_op(name, as_op(a)); \ -} - -#define BINARY_OPERATOR(op, name) \ -template \ -Op& operator op(const T& a, const U& b) { \ - return func_op(name, as_op(a), as_op(b)); \ -} - -#define UNARY_FUNCTION(name, opname) \ -template \ -Op& name(const T& a) { \ - return func_op(opname, as_op(a)); \ -} - -#define BINARY_FUNCTION(name, opname) \ -template \ -Op& name(const T& a, const U& b) { \ - return func_op(opname, as_op(a), as_op(b)); \ -} - -#define TERNARY_FUNCTION(name, opname) \ -template \ -Op& name(const T& cond, const U& x, const V& y) { \ - return func_op(opname, as_op(cond), as_op(x), as_op(y)); \ -} - -UNARY_OPERATOR(+, "pos") -UNARY_OPERATOR(-, "neg") -UNARY_OPERATOR(~, "not") -UNARY_OPERATOR(!, "lnot") - -BINARY_OPERATOR(+, "add") -BINARY_OPERATOR(-, "sub") -BINARY_OPERATOR(*, "mul") -BINARY_OPERATOR(/, "div") -BINARY_OPERATOR(%, "mod") -BINARY_OPERATOR(&, "and") -BINARY_OPERATOR(|, "or") -BINARY_OPERATOR(^, "xor") -BINARY_OPERATOR(<<, "lshift") -BINARY_OPERATOR(>>, "rshift") -BINARY_OPERATOR(==, "eq") -BINARY_OPERATOR(!=, "neq") -BINARY_OPERATOR(<, "lt") -BINARY_OPERATOR(<=, "lte") -BINARY_OPERATOR(>, "gt") -BINARY_OPERATOR(>=, "gte") -BINARY_OPERATOR(&&, "land") -BINARY_OPERATOR(||, "lor") - -UNARY_FUNCTION(copy, "copy") -UNARY_FUNCTION(sin, "sin") -UNARY_FUNCTION(cos, "cos") -UNARY_FUNCTION(tan, "tan") -UNARY_FUNCTION(asin, 
"asin") -UNARY_FUNCTION(acos, "acos") -UNARY_FUNCTION(atan, "atan") -UNARY_FUNCTION(sinh, "sinh") -UNARY_FUNCTION(cosh, "cosh") -UNARY_FUNCTION(tanh, "tanh") -UNARY_FUNCTION(asinh, "asinh") -UNARY_FUNCTION(acosh, "acosh") -UNARY_FUNCTION(atanh, "atanh") -UNARY_FUNCTION(exp, "exp") -UNARY_FUNCTION(log, "log") -UNARY_FUNCTION(log2, "log2") -UNARY_FUNCTION(exp2, "exp2") -UNARY_FUNCTION(sqrt, "sqrt") -UNARY_FUNCTION(sqr, "sqr") -UNARY_FUNCTION(rsqrt, "rsqrt") -UNARY_FUNCTION(rcp, "rcp") -UNARY_FUNCTION(abs, "abs") -UNARY_FUNCTION(sign, "sign") -UNARY_FUNCTION(floor, "floor") -UNARY_FUNCTION(ceil, "ceil") -UNARY_FUNCTION(round, "round") -UNARY_FUNCTION(trunc, "trunc") -UNARY_FUNCTION(frac, "frac") -UNARY_FUNCTION(pcg, "pcg") -UNARY_FUNCTION(pcgf, "pcgf") -UNARY_FUNCTION(reversebits, "reversebits") -UNARY_FUNCTION(tofloat, "tofloat") -UNARY_FUNCTION(toint, "toint") -UNARY_FUNCTION(touint, "touint") -UNARY_FUNCTION(tobool, "tobool") -UNARY_FUNCTION(asfloat, "asfloat") -UNARY_FUNCTION(asint, "asint") -UNARY_FUNCTION(asuint, "asuint") -UNARY_FUNCTION(clamp, "clamp") - -BINARY_FUNCTION(pow, "pow") -BINARY_FUNCTION(min, "min") -BINARY_FUNCTION(max, "max") -BINARY_FUNCTION(mod, "mod") -BINARY_FUNCTION(modf, "modf") -BINARY_FUNCTION(atan2, "atan2") -BINARY_FUNCTION(grad, "backwards_grad") - -TERNARY_FUNCTION(lerp, "lerp") -TERNARY_FUNCTION(smoothstep, "smoothstep") -TERNARY_FUNCTION(select, "ternary") -TERNARY_FUNCTION(fma, "fma") - - -// Arithmetic operations -Op& operator+(const Op& a, const Op& b); -Op& operator-(const Op& a, const Op& b); -Op& operator*(const Op& a, const Op& b); -Op& operator/(const Op& a, const Op& b); -Op& operator%(const Op& a, const Op& b); - -// Bitwise operations -Op& operator&(const Op& a, const Op& b); -Op& operator|(const Op& a, const Op& b); -Op& operator^(const Op& a, const Op& b); -Op& operator<<(const Op& a, const Op& b); -Op& operator>>(const Op& a, const Op& b); -Op& operator~(const Op& a); - -// Comparison operations -Op& operator==(const 
Op& a, const Op& b); -Op& operator!=(const Op& a, const Op& b); -Op& operator<(const Op& a, const Op& b); -Op& operator<=(const Op& a, const Op& b); -Op& operator>(const Op& a, const Op& b); -Op& operator>=(const Op& a, const Op& b); - -// Logical operations -Op& operator&&(const Op& a, const Op& b); -Op& operator||(const Op& a, const Op& b); -Op& operator!(const Op& a); - -// Increment and decrement operations -Op& operator++(const Op& a); -Op& operator--(const Op& a); - -Op& operator+=(const Op& a, const Op& b); -Op& operator-=(const Op& a, const Op& b); - -Op& copy(const Op& a); -Op& sin(const Op& a); -Op& cos(const Op& a); -Op& tan(const Op& a); -Op& asin(const Op& a); -Op& acos(const Op& a); -Op& atan(const Op& a); -Op& sinh(const Op& a); -Op& cosh(const Op& a); -Op& tanh(const Op& a); -Op& asinh(const Op& a); -Op& acosh(const Op& a); -Op& atanh(const Op& a); -Op& exp(const Op& a); -Op& log(const Op& a); -Op& log2(const Op& a); -Op& exp2(const Op& a); -Op& sqrt(const Op& a); -Op& sqr(const Op& a); -Op& rsqrt(const Op& a); -Op& rcp(const Op& a); -Op& abs(const Op& a); -Op& sign(const Op& a); -Op& floor(const Op& a); -Op& ceil(const Op& a); -Op& round(const Op& a); -Op& trunc(const Op& a); -Op& frac(const Op& a); - -Op& pcg(const Op& a); -Op& pcgf(const Op& a); - -Op& reversebits(const Op& a); - -Op& tofloat(const Op& a); -Op& toint(const Op& a); -Op& touint(const Op& a); -Op& tobool(const Op& a); - -Op& asfloat(const Op& a); -Op& asint(const Op& a); -Op& asuint(const Op& a); - -Op& clamp(const Op& x, const Op& min, const Op& max); -Op& pow(const Op& x, const Op& y); -Op& min(const Op& x, const Op& y); -Op& max(const Op& x, const Op& y); -Op& mod(const Op& x, const Op& y); -Op& modf(const Op& x, const Op& y); -Op& atan2(const Op& x, const Op& y); -Op& grad(const Op& x, const Op& wrt); -Op& lerp(const Op& x, const Op& y, const Op& a); -Op& smoothstep(const Op& a, const Op& b, const Op& x); -Op& select(const Op& cond, const Op& x, const Op& y); -Op& fma(const Op& 
x, const Op& y, const Op& z); - -} diff --git a/TensorFrost/IR/src/Overloads.cpp b/TensorFrost/IR/src/Overloads.cpp deleted file mode 100644 index ce2fc259..00000000 --- a/TensorFrost/IR/src/Overloads.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "../include/Overloads.h" -#include "../include/ExecutionContext.h" -#include "../include/OperationRegistry.h" -#include "../include/OperationArguments.h" - -using namespace TensorFrost; -using namespace std; - -// General function to create an Op instance in the current execution context -Op& make_op(string op, vector mem, vector ids, vector args, vector shape) { - OpSpec* spec = GetOpSpec(op); - vector arg_types; - for (const auto& arg : args) { - arg_types.push_back(arg->type); - } - TFDataFormat output_type = spec->GetOutputType(arg_types); - Op* op_instance = new Op(op); - op_instance->type = output_type; - op_instance->args->SetArguments(ArgType::Memory, mem); - op_instance->args->SetArguments(ArgType::Index, ids); - op_instance->args->SetArguments(ArgType::Input, args); - op_instance->args->SetArguments(ArgType::Shape, shape); - return GetContext()->AddOp(std::unique_ptr(op_instance)); -} - -Op& constant(int value) { - Op& const_op = func_op("const"); - const_op.attributes["value"] = value; - const_op.type = TFTypeInt32; - return const_op; -} - -Op& constant(uint value) { - Op& const_op = func_op("const"); - const_op.attributes["value"] = value; - const_op.type = TFTypeUint32; - return const_op; -} - -Op& constant(float value) { - Op& const_op = func_op("const"); - const_op.attributes["value"] = value; - const_op.type = TFTypeFloat32; - return const_op; -} - -Op& constant(bool value) { - Op& const_op = func_op("const"); - const_op.attributes["value"] = value; - const_op.type = TFTypeBool32; - return const_op; -} - -// Arithmetic operations -Op& operator+(const Op& a, const Op& b) { return func_op("add", a, b); } -Op& operator-(const Op& a, const Op& b) { return func_op("sub", a, b); } -Op& operator*(const Op& a, 
const Op& b) { return func_op("mul", a, b); } -Op& operator/(const Op& a, const Op& b) { return func_op("div", a, b); } -Op& operator%(const Op& a, const Op& b) { return func_op("mod", a, b); } - -// Bitwise operations -Op& operator&(const Op& a, const Op& b) { return func_op("and", a, b); } -Op& operator|(const Op& a, const Op& b) { return func_op("or", a, b); } -Op& operator^(const Op& a, const Op& b) { return func_op("xor", a, b); } -Op& operator<<(const Op& a, const Op& b) { return func_op("lshift", a, b); } -Op& operator>>(const Op& a, const Op& b) { return func_op("rshift", a, b); } -Op& operator~(const Op& a) { return func_op("not", a); } - -// Comparison operations -Op& operator==(const Op& a, const Op& b) { return func_op("eq", a, b); } -Op& operator!=(const Op& a, const Op& b) { return func_op("neq", a, b); } -Op& operator<(const Op& a, const Op& b) { return func_op("lt", a, b); } -Op& operator<=(const Op& a, const Op& b) { return func_op("lte", a, b); } -Op& operator>(const Op& a, const Op& b) { return func_op("gt", a, b); } -Op& operator>=(const Op& a, const Op& b) { return func_op("gte", a, b); } - -// Logical operations -Op& operator&&(const Op& a, const Op& b) { return func_op("land", a, b); } -Op& operator||(const Op& a, const Op& b) { return func_op("lor", a, b); } -Op& operator!(const Op& a) { return func_op("lnot", a); } - -// Assignment operations -// Op& operator+=(const Op& a, const Op& b) { return func_op("add_assign", a, b); } -// Op& operator-=(const Op& a, const Op& b) { return func_op("sub_assign", a, b); } -// Op& operator*=(const Op& a, const Op& b) { return func_op("mul_assign", a, b); } -// Op& operator/=(const Op& a, const Op& b) { return func_op("div_assign", a, b); } -// Op& operator%=(const Op& a, const Op& b) { return func_op("mod_assign", a, b); } -// Op& operator&=(const Op& a, const Op& b) { return func_op("and_assign", a, b); } -// Op& operator|=(const Op& a, const Op& b) { return func_op("or_assign", a, b); } -// Op& 
operator^=(const Op& a, const Op& b) { return func_op("xor_assign", a, b); } -// Op& operator<<=(const Op& a, const Op& b) { return func_op("lshift_assign", a, b); } -// Op& operator>>=(const Op& a, const Op& b) { return func_op("rshift_assign", a, b); } -// Op& operator++(const Op& a) { return a += 1; } -// Op& operator--(const Op& a) { return a -= 1; } - -Op& copy(const Op& a) { return func_op("copy", a); } -Op& sin(const Op& a) { return func_op("sin", a); } -Op& cos(const Op& a) { return func_op("cos", a); } -Op& tan(const Op& a) { return func_op("tan", a); } -Op& asin(const Op& a) { return func_op("asin", a); } -Op& acos(const Op& a) { return func_op("acos", a); } -Op& atan(const Op& a) { return func_op("atan", a); } -Op& sinh(const Op& a) { return func_op("sinh", a); } -Op& cosh(const Op& a) { return func_op("cosh", a); } -Op& tanh(const Op& a) { return func_op("tanh", a); } -Op& asinh(const Op& a) { return func_op("asinh", a); } -Op& acosh(const Op& a) { return func_op("acosh", a); } -Op& atanh(const Op& a) { return func_op("atanh", a); } -Op& exp(const Op& a) { return func_op("exp", a); } -Op& log(const Op& a) { return func_op("log", a); } -Op& log2(const Op& a) { return func_op("log2", a); } -Op& exp2(const Op& a) { return func_op("exp2", a); } -Op& sqrt(const Op& a) { return func_op("sqrt", a); } -Op& sqr(const Op& a) { return func_op("sqr", a); } -Op& rsqrt(const Op& a) { return func_op("rsqrt", a); } -Op& rcp(const Op& a) { return func_op("rcp", a); } -Op& abs(const Op& a) { return func_op("abs", a); } -Op& sign(const Op& a) { return func_op("sign", a); } -Op& floor(const Op& a) { return func_op("floor", a); } -Op& ceil(const Op& a) { return func_op("ceil", a); } -Op& round(const Op& a) { return func_op("round", a); } -Op& trunc(const Op& a) { return func_op("trunc", a); } -Op& frac(const Op& a) { return func_op("frac", a); } -Op& pcg(const Op& a) { return func_op("pcg", a); } -Op& pcgf(const Op& a) { return func_op("pcgf", a); } -Op& reversebits(const 
Op& a) { return func_op("reversebits", a); } - -Op& clamp(const Op& x, const Op& min, const Op& max) { return func_op("clamp", x, min, max); } -Op& pow(const Op& x, const Op& y) { return func_op("pow", x, y); } -Op& min(const Op& x, const Op& y) { return func_op("min", x, y); } -Op& max(const Op& x, const Op& y) { return func_op("max", x, y); } -Op& mod(const Op& x, const Op& y) { return func_op("mod", x, y); } -Op& modf(const Op& x, const Op& y) { return func_op("modf", x, y); } -Op& atan2(const Op& x, const Op& y) { return func_op("atan2", x, y); } -Op& grad(const Op& x, const Op& wrt) { return func_op("backwards_grad", x, wrt); } -Op& lerp(const Op& x, const Op& y, const Op& a) { return func_op("lerp", x, y, a); } -Op& smoothstep(const Op& a, const Op& b, const Op& x) { return func_op("smoothstep", a, b, x); } -Op& select(Op& cond, const Op& x, const Op& y) { return func_op("ternary", cond, x, y); } -Op& fma(const Op& x, const Op& y, const Op& z) { return func_op("fma", x, y, z); } - -// Type conversion operations -Op& tofloat(const Op& a) { return func_op("tofloat", a); } -Op& toint(const Op& a) { return func_op("toint", a); } -Op& touint(const Op& a) { return func_op("touint", a); } -Op& tobool(const Op& a) { return func_op("tobool", a); } - -Op& asfloat(const Op& a) { return func_op("asfloat", a); } -Op& asint(const Op& a) { return func_op("asint", a); } -Op& asuint(const Op& a) { return func_op("asuint", a); } \ No newline at end of file diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp new file mode 100644 index 00000000..4de7cf13 --- /dev/null +++ b/TensorFrost/PybindModule.cpp @@ -0,0 +1,134 @@ +#include +#include + +// #include +// #include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace TensorFrost { +// +// void PyTensorDefinition(py::module&, py::class_&); +// void TensorFunctionsDefinition(py::module&); +// void TensorProgramDefinition(py::module&, py::class_&); +// void 
TensorMemoryDefinition(py::module& m, +// py::class_& py_tensor_mem); +// void WindowDefinitions(py::module& m); +// void ScopeDefinitions(py::module& m, py::class_& py_tensor); +// void ModuleDefinitions(py::module& m); + +PYBIND11_MODULE(TensorFrost, m) { + m.doc() = "TensorFrost library"; + // auto data_type = py::enum_(m, "TFType"); + // auto backend_type = py::enum_(m, "BackendType"); + // auto code_gen_lang = py::enum_(m, "CodeGenLang"); + // auto py_tensor = py::class_(m, "Tensor"); + // auto tensor_program = py::class_(m, "TensorProgram"); + // auto py_tensor_mem = py::class_(m, "TensorMemory"); + // auto py_tensor_arg = py::class_(m, "Arg"); + // + // data_type.value("float", TFType::Float); + // data_type.value("int", TFType::Int); + // data_type.value("uint", TFType::Uint); + // data_type.value("bool", TFType::Bool); + // + // auto data_format = py::class_(m, "TFDataFormat"); + // data_format.def(py::init()); + // data_format.def_readwrite("type", &TFDataFormat::type); + // data_format.def_readwrite("size", &TFDataFormat::size); + // // Add printers for the enums + // data_format.def("__repr__", [](const TFDataFormat& a) { + // return ""; + // }); + // data_format.def("__str__", [](const TFDataFormat& a) { + // return ""; + // }); + // data_type.def("__repr__", [](TFType a) { + // return ""; + // }); + // data_type.def("__str__", [](TFType a) { + // return ""; + // }); + // + // backend_type.value("cpu", BackendType::CPU); + // backend_type.value("vulkan", BackendType::Vulkan); + // backend_type.value("opengl", BackendType::OpenGL); + // backend_type.value("codegen", BackendType::CodeGen); + // code_gen_lang.value("cpp", CodeGenLang::CPP); + // code_gen_lang.value("glsl", CodeGenLang::GLSL); + // code_gen_lang.value("hlsl", CodeGenLang::HLSL); + // + // data_format.def(py::self == py::self); + // backend_type.def("__eq__", [](BackendType a, BackendType b) { + // return a == b; + // }); + // code_gen_lang.def("__eq__", [](CodeGenLang a, CodeGenLang b) { + 
// return a == b; + // }); + // data_type.def("__eq__", [](TFType a, TFType b) { + // return a == b; + // }); + // + // m.attr("float32") = TFTypeFloat32; + // m.attr("int32") = TFTypeInt32; + // m.attr("uint32") = TFTypeUint32; + // m.attr("bool1") = TFTypeBool32; + // + // m.attr("cpu") = BackendType::CPU; + // m.attr("vulkan") = BackendType::Vulkan; + // m.attr("opengl") = BackendType::OpenGL; + // m.attr("codegen") = BackendType::CodeGen; + // + // m.attr("cpp_lang") = CodeGenLang::CPP; + // m.attr("glsl_lang") = CodeGenLang::GLSL; + // m.attr("hlsl_lang") = CodeGenLang::HLSL; + // + // PyTensorDefinition(m, py_tensor); + // + // // implicit conversion from TensorView to PyTensor + // py::implicitly_convertible(); + // py::implicitly_convertible(); + // py::implicitly_convertible(); + // py::implicitly_convertible(); + // + // TensorFunctionsDefinition(m); + // TensorProgramDefinition(m, tensor_program); + // TensorMemoryDefinition(m, py_tensor_mem); + // WindowDefinitions(m); + // ScopeDefinitions(m, py_tensor); + // ModuleDefinitions(m); + // + // py_tensor_arg.def(py::init([](py::list shape, TFDataFormat type) { + // std::vector shape_vec; + // for (auto& s : shape) { + // shape_vec.push_back(s.cast()); + // } + // return PyTensorArg(shape_vec, type); + // }), "Create a TensorArg with the given shape and type", py::return_value_policy::take_ownership); + // + // m.def("current_backend", []() { + // return current_backend; + // }, "Get the current backend"); + // + // m.def("initialize", + // [](BackendType backend_type, const std::string& kernel_compile_options, CodeGenLang kernel_lang) { + // InitializeBackend(backend_type, kernel_compile_options, kernel_lang); + // }, py::arg("backend_type") = BackendType::CPU, py::arg("kernel_compile_options") = "", py::arg("kernel_lang") = CodeGenLang::None, "Initialize the backend"); + // + // m.def("strip_debug_info", [](bool strip) { + // strip_debug_names = strip; + // }, py::arg("strip") = true, "Strip debug info 
from the kernel"); + +#ifdef NDEBUG + py::print("TensorFrost module loaded!"); +#else + py::print("TensorFrost module loaded in debug mode! Expect slow performance."); +#endif +} + +} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Tensor/Tensor.cpp b/TensorFrost/Tensor/Tensor.cpp deleted file mode 100644 index 000620d5..00000000 --- a/TensorFrost/Tensor/Tensor.cpp +++ /dev/null @@ -1,357 +0,0 @@ -#include "Tensor.h" -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { - -Tensors Reverse(const Tensors& tensors) { - Tensors reversed; - for (int i = (int)tensors.size() - 1; i >= 0; i--) { - reversed.push_back(tensors[i]); - } - return reversed; -} - -vector Reverse(const vector& vec) { - vector reversed; - for (int i = (int)vec.size() - 1; i >= 0; i--) { - reversed.push_back(vec[i]); - } - return reversed; -} - -int ReverseDim(int dim, size_t dims) { - return (int)dims - dim - 1; -} - -IR* Tensor::evaluation_context_ir_ = nullptr; - -Node::~Node() { delete tensor_; } - -std::string MakeNodeErrorMessage(std::string message, std::initializer_list nodes) { - message += "\n"; - for (auto node : nodes) { - message += GetNodeString(node) + "\n"; - } - return message; -} - -vector ShapeInfo::GetShape(int default_value) const { - vector shape; - for (auto [node, _]: this->shape) { - if(node->op->class_ == OpClass::Constant) { - shape.push_back(node->tensor_->TryGetConstant()); - } else { - shape.push_back(default_value); - } - } - return shape; -} - -void ShapeInfo::ExpandDimensionsTo(int new_dim) -{ - if(new_dim <= dim) { - return; - } - Tensor& one = Tensor::Constant(1); - for(int i = dim; i < new_dim; i++) { - InsertDim(i, one.node_, true); - } -} - -float ShapeInfo::GetSizeEstimate(ShapeInfo& shape) { - vector shape_a = shape.GetShape(); - float size_a = 1.0f; - for (int i = 0; i < shape_a.size(); i++) { - size_a *= (float)shape_a[i]; - } - return size_a; -} - -void ArgumentManager::AddArgument(ArgID id, Node* node) { - 
if(node == nullptr) { - throw std::runtime_error("Node is null"); - } - inputs_[id] = node; - argument_types_[id] = node->format; - argument_counts_[id.first]++; - //add this node as an output of the argument - node->args.AddOutput(id, node_); -} - -void ArgumentManager::Remove(ArgID id) { - if(inputs_.find(id) == inputs_.end()) { - throw std::runtime_error("Cannot remove argument that does not exist"); - } - //remove this node as an output of the argument - inputs_[id]->args.RemoveOutput(id, node_); - inputs_.erase(id); - argument_types_.erase(id); - argument_counts_[id.first]--; -} - -void ArgumentManager::RemoveArguments(ArgType arg) { - vector to_remove; - for (auto& [id, node] : inputs_) { - if (id.first == arg) { - to_remove.push_back(id); - } - } - for (auto& id : to_remove) { - Remove(id); - } -} - -vector ArgumentManager::GetTensorVector(ArgType type) const { - vector tensors = vector(Count(type)); - for (auto& [id, node] : inputs_) { - if (id.first == type) { - tensors[id.second] = node->tensor_; - } - } - return tensors; -} - -tuple Tensor::GetOperation(const string &name, const Tensors &tensors, - bool check_shape) { - vector input_types = vector(); - for (const auto& tensor : tensors) { - input_types.push_back(tensor->node_->format); - } - - const Operation* operation = FindOperation(name); - - // check if input is valid - if (!operation->IsInputValid(input_types)) { - string error = "Input types ("; - for (int i = 0; i < input_types.size(); i++) { - error += DataTypeToString(input_types[i].type) + "(" + to_string(input_types[i].size) + ")"; - if (i < input_types.size() - 1) { - error += ", "; - } - } - error += ") are not valid for operation \"" + name + "\""; - throw std::runtime_error(error); - } - - ShapeInfo shape_info = ShapeInfo(); - - if (check_shape) - { - //check if shapes are compatible and get the final broadcasted shape - for (int i = 0; i < tensors.size(); i++) { - ShapeInfo shape_info2 = ShapeInfo(tensors[i]->node_); - auto result = 
CompareShape(shape_info, shape_info2, true); - shape_info = result.broadcast_shape; - } - } - - TFDataFormat output_type = operation->GetOutputType(input_types); - - return {operation, output_type, shape_info}; -} - -bool Tensor::CheckIndices(const Tensors &indices) { - for (const Tensor* index : indices) { - if (index->node_->format.type != TFType::Int && index->node_->format.type != TFType::Uint) { - return false; - } - } - return true; -} - -TFDataFormat Tensor::GetFormat() const { return node_->format; } - -void Tensor::SetData(const vector &data) const { - node_->data = data; -} - -void Tensor::SetData(uint data) const { - SetData(vector(1, data)); -} - -void Tensor::SetData(float data) const { - SetData(vector(1, AsUint(data))); -} - -void Tensor::SetData(int data) const { - SetData(vector(1, AsUint(data))); -} - -void Tensor::SetFormat(TFDataFormat type) const { - node_->format = type; -} - -void Tensor::DetachGrad() const { - node_->flags.set(NodeProp::DetachGrad); -} - -void Tensor::PassGrad() const { - node_->flags.set(NodeProp::PassGrad); -} - -void Tensor::StopFusion() const { - node_->flags.set(NodeProp::StopFusion); -} - -void Tensor::HintRange(float min, float max) const { - node_->flags.set(NodeProp::HintMinValue, (int64_t)AsUint(min)); - node_->flags.set(NodeProp::HintMaxValue, (int64_t)AsUint(max)); -} - -void Tensor::HintRange(int min, int max) const { - node_->flags.set(NodeProp::HintMinValue, (int64_t)min); - node_->flags.set(NodeProp::HintMaxValue, (int64_t)max); -} - -void Tensor::HintRange(uint min, uint max) const { - node_->flags.set(NodeProp::HintMinValue, (int64_t)min); - node_->flags.set(NodeProp::HintMaxValue, (int64_t)max); -} - -Tensor* Tensor::GetCopy(const Tensor& other, NodeArguments args) { - Tensor* copy = &CreateNode(other.node_->format, std::move(args), other.node_->name); - copy->node_->data = other.node_->data; - copy->node_->CopyProperties(other.node_); - return copy; -} - -Tensor* Tensor::GetCopy(const Tensor& other) { - 
NodeArguments new_args; - for (auto& [id, from] : other.node_->args.Inputs()) { - new_args[id] = from; - } - return GetCopy(other, new_args); -} - -void Tensor::SetShape(Tensors shape) const { - node_->args.RemoveArguments(ArgType::Shape); - for (int i = 0; i < shape.size(); i++) { - node_->args.AddArgument(ArgType::Shape, i, shape[i]->node_); - } -} - -Tensors Tensor::GetInputShapeTensors(Tensors shape) { - Tensors result = Tensors(); - for (int dim = 0; dim < shape.size(); dim++) { - const Tensor* tensor = shape[dim]; - //check if tensor is a negative constant - if (tensor->node_->name == "const" && (*(int*)&(tensor->node_->data[0])) < 0) - { - Tensor& mem = Static("input_shape", {TFType::Int, 32}); - //make sure its reversed on the backend - mem.node_->flags.set(NodeProp::InputShapeDim, (int64_t)(shape.size() - dim - 1)); - result.push_back(&mem); - } - else - { - result.push_back(tensor); - } - } - return result; -} - -//Get values from a tensor at the given indices -Tensor& Tensor::Load(const Tensor& tensor, const Tensors& indices, IndexingMode mode) { - Tensor& out = MemoryOp("load", &tensor, indices); - out.node_->indexing_mode_ = mode; - out.SetData(0); - out.SetDebugName(tensor.node_->debug_name); - return out; -} - -Tensor& Tensor::Store(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode) { - Tensor& out = MemoryOp("store", &tensor, indices, &value); - out.node_->indexing_mode_ = mode; - return out; -} - -Tensor & Tensor::ReductionOP(string name, const Tensor &tensor, int axis, bool keepdims) { - // get the shape of the tensor (all dimensions except the last one) - Tensors shape = tensor.GetShape(); - axis = GetAxis((int)shape.size(), axis); - - //check if axis is valid - if (axis < 0 || axis >= shape.size()) { - throw std::runtime_error("Invalid axis for reduction operation " + name); - } - - // remove the axis dimension - shape.erase(shape.begin() + axis); - if (shape.empty()) { - shape.push_back(&Constant(1)); - } - 
Tensor& op = OpShape(name, shape, &tensor); - op.node_->data = vector(1, axis); - //if(keepdims) { - // op.node_->AddFlag(NodeFlag::KeepDims); - //} - return op; -} - -Tensor & Tensor::ScanOP(string name, const Tensor &tensor, int axis) { - Tensor& op = Op(name, &tensor); - op.node_->data = vector(1, axis); - return op; -} - -bool Tensor::AreTensorsEqual(const Tensor &a, const Tensor &b) { - if(a.node_->op->class_ == OpClass::Constant && b.node_->op->class_ == OpClass::Constant) { - return a.node_->data[0] == b.node_->data[0]; - } - if(a.node_ == b.node_) { - return true; - } - return false; -} - -Tensor& Tensor::Reshape(const Tensor& tensor, const Tensors& shape, TFDataFormat format) { - Tensor& out = MemoryOpShape("reshape", shape, &tensor); - out.SetDebugName(tensor.node_->debug_name); - if(format.type != TFType::None) { - out.node_->format = format; - } else { - out.node_->format = tensor.node_->format; - } - return out; -} - -Tensor & Tensor::Assert(const Tensor &tensor, const Tensors &shape, TFDataFormat type) { - Tensor& out = MemoryOpShape("assert", shape, &tensor); - out.SetDebugName(tensor.node_->debug_name); - out.node_->format = type; - return out; -} - -void Tensor::SetDebugName(const string& name) const -{ - if (name != "") { - node_->debug_name = name; - } - - if (strip_debug_names) { - static int count = 0; - - node_->debug_name = "tensor_" + std::to_string(count++); - } -} - -void Tensor::BeginRegion(const string& name) { - Tensor& t = Static("region_begin", {TFType::None, 0}); - t.SetDebugName(name); -} - -void Tensor::EndRegion(const string& name) { - Tensor& t = Static("region_end", {TFType::None, 0}); - t.SetDebugName(name); -} - -const Tensor* Node::GetTensor() const { - if (tensor_->node_ != this) { - throw std::runtime_error("Fatal Error: Tensor node does not match"); - } - return tensor_; -} - - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Tensor/Tensor.h b/TensorFrost/Tensor/Tensor.h deleted file mode 
100644 index 40deee5e..00000000 --- a/TensorFrost/Tensor/Tensor.h +++ /dev/null @@ -1,1228 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -// for FLT_MAX, INT_MAX, etc. -#include -#include - -#include - -#include "Compiler/Graph/IR.h" -#include "Utility/Utility.h" - -namespace TensorFrost { - -using Tensors = vector; - -Tensors Reverse(const Tensors& tensors); -vector Reverse(const vector& vec); -int ReverseDim(int dim, size_t dims); - -class Tensor { - private: - static IR* evaluation_context_ir_; - - static Tensor& CreateNode(TFDataFormat type, NodeArguments args, string name) { - if (evaluation_context_ir_ == nullptr) { - throw std::runtime_error( - "Evaluation context has not been set. Are you doing operations " - "without compiling first?"); - } - - //TODO: merge Tensor and Node classes - auto* tensor = new Tensor(); - tensor->node_ = evaluation_context_ir_->AddNode(tensor, std::move(args), std::move(name), type); - return *tensor; - } - - static void AddArgument(NodeArguments& arguments, const Tensor* tensor, - ArgType type, int index = 0) { - arguments[ArgID(type, index)] = tensor->node_; - } - - static void AddArguments(NodeArguments& arguments, const Tensors& tensors, - ArgType type) { - for (int i = 0; i < tensors.size(); i++) { - AddArgument(arguments, tensors[i], type, i); - } - } - - static void AddArguments(NodeArguments& arguments, const NodeArguments& toadd) { - for (const auto& i : toadd) { - arguments[i.first] = i.second; - } - } - - static bool AssertTensorShape(const Tensor* a, const Tensor* b, bool throw_error = true) { - return CompareShape(a->node_, b->node_, throw_error).compatible; - } - - static tuple GetOperation(const string& name, - const Tensors& tensors, bool check_shape = true); - - template - static Tensor& Op(std::string op, const Args*... 
args) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Operation name cannot be empty"); - } - - // convert the parameter pack to a std::vector - Tensors tensors = {args...}; - - // get the operation and output type - auto [operation, output_type, shape_info] = GetOperation(op, tensors); - - // create argument list - NodeArguments arguments = NodeArguments(); - - AddArguments(arguments, tensors, ArgType::Input); - - AddArguments(arguments, shape_info.GetArguments()); - - return CreateNode(output_type, arguments, op); - } - -public: - static Tensor& OpShape(std::string op, Tensors shape, Tensors tensors) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Operation name cannot be empty"); - } - - // get the operation and output type - auto [operation, output_type, shape_info] = GetOperation(op, tensors, false); - - // create argument list - NodeArguments arguments = NodeArguments(); - - AddArguments(arguments, tensors, ArgType::Input); - AddArguments(arguments, shape, ArgType::Shape); - - return CreateNode(output_type, arguments, op); - } - - template - static Tensor& OpShape(std::string op, Tensors shape, const Args*... args) { - // convert the parameter pack to a std::vector - Tensors tensors = {args...}; - - return OpShape(op, shape, tensors); - } - - template - static Tensor& MemoryOpShape(std::string op, Tensors shape, const Tensor* memory, const Args*... 
args) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Memory operation name cannot be empty"); - } - - // convert the parameter pack to a std::vector - Tensors tensors = {args...}; - - // get the operation and output type - auto [operation, output_type, shape_info] = GetOperation(op, tensors); - - // create argument list - NodeArguments arguments = NodeArguments(); - - AddArgument(arguments, memory, ArgType::Memory); - AddArguments(arguments, tensors, ArgType::Input); - AddArguments(arguments, shape, ArgType::Shape); - - return CreateNode(output_type, arguments, op); - } - - static bool CheckIndices(const Tensors& indices); - - template - static Tensor& MemoryOp(string op, const Tensor* memory, - const Tensors indices, const Args*... args) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Memory operation name cannot be empty"); - } - - if(!CheckIndices(indices)) { - throw std::runtime_error("Tensor indices must be integers"); - } - - //check if memory op is local - bool is_local = false; - if(memory->node_->op->HasAllTypes(OpProp::LocalMemory)) { - is_local = true; - } - - //if local, can only have 1 index - if(is_local && indices.size() > 1) { - throw std::runtime_error("Local memory operations can only have 1 index"); - } - - //if global, can only have up to memory dimension indices - size_t memory_dim = memory->GetDimension(); - if(!is_local && indices.size() > memory_dim) { - throw std::runtime_error("Too many indices for memory operation, memory has " + std::to_string(memory_dim) + " dimensions, while " + std::to_string(indices.size()) + " indices were provided"); - } - - // convert the parameter pack to a std::vector - Tensors tensors = {args...}; - - // get the operation and output type - auto [operation, output_type, shape_info] = GetOperation(op, tensors); - - if (operation->HasAllTypes(OpProp::Modifier)) - { - memory->node_->flags.set(NodeProp::Modified); - } - - // create argument list - 
NodeArguments arguments = NodeArguments(); - - AddArgument(arguments, memory, ArgType::Memory); - AddArguments(arguments, tensors, ArgType::Input); - AddArguments(arguments, indices, ArgType::Index); - - // get an input node that has shape arguments - NodeArguments shape_arguments = shape_info.GetArguments(); - - //use index shape instead if no input shape is found - if (shape_arguments.empty()) - { - for (const Tensor* index : indices) - { - shape_arguments = index->node_->args.GetArguments(ArgType::Shape); - if (!shape_arguments.empty()) { - break; - } - } - } - - //if no indices or inputs exist, use memory shape - if (indices.empty() && tensors.empty()) - { - shape_arguments = memory->node_->args.GetArguments(ArgType::Shape); - } - - AddArguments(arguments, shape_arguments); - - if (op == "load") output_type = memory->GetFormat(); - - Tensor& output = CreateNode(output_type, arguments, op); - if(is_local) output.node_->flags.set(NodeProp::LocalMemoryOp); - return output; - } - - static Tensor& Static(string op, const NodeArguments& shape, - const TFDataFormat format) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Static operation name cannot be empty"); - } - - const Operation* operation = FindOperation(op); - // check if output is valid - if (!operation->IsOutputValid(format)) { - throw std::runtime_error("Type " + DataTypeToString(format.type) + "(" + std::to_string(format.size) + ") is not valid for operation \"" + op + "\""); - } - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape); - return CreateNode(format, arguments, op); - } - - static Tensor& Static(const string& op, const Tensors& shape, - const TFDataFormat type) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - return Static(op, arguments, type); - } - - static Tensor& Static(const string& op, const TFDataFormat type) { - return Static(op, NodeArguments(), type); - } - - static void 
SetEvaluationContext(IR* ir) { - if (evaluation_context_ir_ != nullptr && ir != nullptr) { - throw std::runtime_error("Evaluation context change is forbidden."); - } - evaluation_context_ir_ = ir; - } - - static IR* GetEvaluationContext() { - return evaluation_context_ir_; - } - - string GetConstantString() const; - - static Tensor& CustomOperation(const string & name, Tensors inputs, Tensors shape) { - return OpShape(name, shape, inputs); - } - - Node* node_ = nullptr; - - TFDataFormat GetFormat() const; - void SetData(const vector& data) const; - void SetData(uint data) const; - void SetData(float data) const; - void SetData(int data) const; - void SetFormat(TFDataFormat type) const; - void DetachGrad() const; - void PassGrad() const; - void StopFusion() const; - void HintRange(float min, float max) const; - void HintRange(int min, int max) const; - void HintRange(uint min, uint max) const; - - static Tensor* GetCopy(const Tensor& other, NodeArguments args); - - static Tensor* GetCopy(const Tensor& other); - - void SetMemoryType(NodeProp memory_type, int index = 0) const { - node_->SetMemoryType(memory_type, index); - } - - int GetDimension() const { - ShapeInfo shape_info = ShapeInfo(node_); - return shape_info.dim; - } - - Tensors GetShape() const { - ShapeInfo shape_info = ShapeInfo(node_); - return shape_info.GetTensors(); - } - - Tensors GetReverseShape() const { - Tensors shape = GetShape(); - std::reverse(shape.begin(), shape.end()); - return shape; - } - - ShapeInfo GetShapeInfo() const { - return ShapeInfo(node_); - } - - void SetShape(Tensors shape) const; - - int TryGetConstant() const { - if (node_->name != "const") { - return -1; - } - return AsInt(node_->data[0]); - } - - // tensor factory methods - static Tensors GetConstantShape(const vector& shape) { - Tensors result = Tensors(); - for (int i : shape) { - result.push_back(&Constant(i)); - } - return result; - } - - static Tensor& Constant(float value) { - Tensor& output = Static("const", 
TFTypeFloat32); - output.SetData(AsUint(value)); - return output; - } - static Tensor& Constant(int value) { - Tensor& output = Static("const", TFTypeInt32); - output.SetData(AsUint(value)); - return output; - } - static Tensor& Constant(uint value) { - Tensor& output = Static("const", TFTypeUint32); - output.SetData(value); - return output; - } - static Tensor& Constant(bool value) { - Tensor& output = Static("const", TFTypeBool32); - output.SetData(value); - return output; - } - static Tensor& Constant(uint value, TFDataFormat type) { - Tensor& output = Static("const", type); - output.SetData(value); - return output; - } - static Tensor& Constant(uint value, const Tensors& shape, TFDataFormat type) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, type); - output.SetData(value); - return output; - } - - static Tensor& Constant(const Tensors& shape, float value) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, TFTypeFloat32); - output.SetData(value); - return output; - } - static Tensor& Constant(const vector& shape, float value) { - return Constant(GetConstantShape(shape), value); - } - static Tensor& Constant(const Tensors& shape, int value) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, TFTypeInt32); - output.SetData(value); - return output; - } - static Tensor& Constant(const vector& shape, int value) { - return Constant(GetConstantShape(shape), value); - } - - static Tensor& Constant(const Tensors& shape, uint value) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, TFTypeUint32); - output.SetData(value); - return output; - } - - static Tensor& Constant(const vector& shape, uint value) { - 
return Constant(GetConstantShape(shape), value); - } - static Tensor& Constant(const Tensors shape, uint value, TFDataFormat type) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, type); - output.SetData(value); - return output; - } - - static Tensors GetShapeTensors(const vector& shape) { - Tensors result = Tensors(); - for (int i : shape) { - result.push_back(&Constant(i)); - } - return result; - } - - static Tensor& Memory(const TFDataFormat type) { return Static("memory", type); } - static Tensor& Memory(const Tensors& shape, - const TFDataFormat type = TFTypeFloat32) { - return Static("memory", shape, type); - } - static Tensor& Memory(const NodeArguments& shape, - const TFDataFormat type = TFTypeFloat32) { - return Static("memory", shape, type); - } - static Tensor& Memory(const vector& shape, - const TFDataFormat type = TFTypeFloat32) { - return Memory(GetShapeTensors(shape), type); - } - - static Tensor& LocalMemory(const int size, const TFDataFormat type) { - Tensor& output = Static("local_memory", type); - output.SetData(size); - return output; - } - - static Tensor& GroupMemory(const int size, const TFDataFormat type) { - Tensor& output = Static("group_memory", type); - output.SetData(size); - return output; - } - - static void GroupBarrier() { - Op("group_barrier"); - } - - static Tensors GetInputShapeTensors(Tensors shape); - - static Tensor& Input(const TFDataFormat type = TFTypeFloat32) { - Tensor& output = Memory(type); - output.SetMemoryType(NodeProp::InputMemory); - return output; - } - static Tensor& Input(const Tensors& shape, - const TFDataFormat type = TFTypeFloat32) { - Tensor& output = Memory(GetInputShapeTensors(shape), type); - output.SetMemoryType(NodeProp::InputMemory); - return output; - } - static Tensor& Input(const vector& shape, - const TFDataFormat type = TFTypeFloat32) { - return Input(GetShapeTensors(shape), type); - } - - static Tensor& 
Index(NodeArguments shape, int dim) { - Tensor& output = Static("dim_id", shape, TFTypeInt32); - output.SetData(dim); - output.SetFormat(TFTypeInt32); - return output; - } - - static Tensor& Index(Tensors shape, int dim) { - Tensor& output = Static("dim_id", shape, TFTypeInt32); - output.SetData(dim); - output.SetFormat(TFTypeInt32); - return output; - } - - static Tensor& Index(const vector& shape, int dim) { - return Index(GetConstantShape(shape), dim); - } - - static Tensors Indices(Tensors shape) { - int dims = (int)shape.size(); - Tensors indices = Tensors(); - for (int i = 0; i < dims; i++) { - indices.push_back(&Index(shape, i)); - } - return indices; - } - - static Tensor& FlatIndex(Tensors shape, Tensors indices) { - int memory_dim = (int)shape.size(); - if(memory_dim == 0) return Constant(0); - // compute the flat index (C-order) - Tensor* flat_index = const_cast(indices[0]); - for (int i = 1; i < memory_dim; i++) { - flat_index = &(*flat_index * *shape[i]); - flat_index = &(*flat_index + *indices[i]); - } - return *flat_index; - } - - static Tensors IndicesFromFlatIndex(const Tensor* index, Tensors shape) - { - size_t dims = shape.size(); - Tensors indices = Tensors(dims); - Tensors sizes = Tensors(dims); - sizes[0] = shape[0]; - for (size_t i = 1; i < dims - 1; i++) { - sizes[i] = &(*sizes[i - 1] * *shape[i]); - } - - Tensor* temp; - for (size_t i = 0; i < dims; i++) { - Tensor* idx0 = const_cast(index); - if (i < dims - 1) { - idx0 = &(*idx0 / *sizes[dims - i - 2]); - } - if (i > 0) { - temp = &(*temp * *shape[dims - i - 1]); - idx0 = &(*idx0 - *temp); - if (i != dims - 1) temp = &(*temp + *idx0); - } else { - temp = idx0; - } - indices[dims - i - 1] = idx0; - } - - return indices; - } - - static Tensor& ElementIndex(Tensors shape) { - return FlatIndex(shape, Indices(shape)); - } - - static Tensor& GetSeed(Tensors shape, const Tensor& seed) { - Tensor* full_seed = &const_cast(seed); - if(full_seed->node_->format.type != TFType::Uint) { - full_seed = 
&asuint(*full_seed); //convert seed to uint - } - full_seed = &(touint(ElementIndex(shape)) + *full_seed * Constant(2654435761u)); - return *full_seed; - } - - static Tensor& Hash(Tensors shape, const Tensor& seed) { - return pcg(GetSeed(shape, seed)); - } - - static Tensor& Random(Tensors shape, const Tensor& seed) { - return pcgf(GetSeed(shape, seed)); - } - - Tensor& BlockIndex() const { - Tensor& output = Static( - "block_id", node_->args.GetArguments(ArgType::Shape), TFTypeInt32); - output.SetFormat(TFTypeInt32); - return output; - } - - Tensor& BlockThreadIndex(int i) const { - Tensor& output = Static( - "block_thread_id", node_->args.GetArguments(ArgType::Shape), TFTypeInt32); - output.SetFormat(TFTypeInt32); - output.SetData(i); - return output; - } - - static Tensor& Load(const Tensor& tensor, const Tensors& indices = Tensors(), - IndexingMode mode = IndexingMode::Clamp); - - static Tensor& Deallocate(const Tensor& tensor) { - return MemoryOp("deallocate", &tensor, {}); - } - - Tensor& Index(int dim) const { - Tensor& output = Static("dim_id", node_->args.GetArguments(ArgType::Shape), TFTypeInt32); - output.SetData(dim); - output.SetFormat(TFTypeInt32); - return output; - } - - Tensors Indices() const { - return Indices(GetShape()); - } - - static Tensor& Store(const Tensor& tensor, const Tensor& value, - const Tensors& indices = Tensors(), IndexingMode mode = IndexingMode::Clamp); - - void Set(const Tensor& value) const { - //check if memory and value shapes are compatible - ShapeCompareResult shape_result = CompareShape(node_, value.node_, true); - if (!shape_result.compatible) { - throw std::runtime_error("Cannot set tensor with incompatible shape"); - } - MemoryOp("set", this, {}, &value); - //update the shape of the tensor - SetShape(shape_result.broadcast_shape.GetTensors()); - } - - static void ScatterAdd(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - MemoryOp("InterlockedAdd", 
&tensor, indices, &value); - } - - static Tensor& ScatterAddPrev(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedAdd_Prev", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - return a; - } - - static void ScatterMax(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedMax", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static void ScatterMin(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedMin", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static void ScatterOr(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedOr", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static void ScatterAnd(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedAnd", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static void ScatterXor(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedXor", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static int GetAxis(int dims, int axis) { - if (axis < 0) { - axis = dims + axis; - } - return axis; - } - - static Tensor& ReductionOP(string name, const Tensor& tensor, int axis = 0, bool keepdims = false); - static Tensor& ScanOP(string name, const Tensor& tensor, int axis = 0); - - static Tensor& Sum(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_sum", tensor, axis); - } - - static Tensor& Norm(const Tensor& tensor, int axis = 0) { - return 
ReductionOP("dim_norm", tensor, axis); - } - - static Tensor& Mean(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_mean", tensor, axis); - } - - static Tensor& Max(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_max", tensor, axis); - } - - static Tensor& Any(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_any", tensor, axis); - } - - static Tensor& All(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_all", tensor, axis); - } - - static Tensor& Min(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_min", tensor, axis); - } - - static Tensor& PrefixSum(const Tensor& tensor, int axis = 0) { - return ScanOP("dim_prefix_sum", tensor, axis); - } - - static Tensor& Reverse(const Tensor& tensor, int axis = 0) { - Tensors shape = tensor.GetShape(); - int dims = (int)shape.size(); - axis = GetAxis(dims, axis); - Tensor& output = OpShape("dim_reverse", shape, &tensor); - output.SetData(axis); - return output; - } - - static Tensor& Repeat(const Tensor& tensor, const Tensor& repeats, int axis = 0) { - //check if repeats is a scalar - if (repeats.GetDimension() != 0) { - throw std::runtime_error("Repeats argument must be a scalar"); - } - int dims = (int)tensor.GetDimension(); - axis = GetAxis(dims, axis); - Tensors shape = tensor.GetShape(); - Tensors new_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - new_shape.push_back(&(*shape[i] * repeats)); - } else { - new_shape.push_back(shape[i]); - } - } - Tensor& output = OpShape("dim_repeat", new_shape, &tensor); - output.SetData(axis); - return output; - } - - static Tensor& SplitDim(const Tensor& tensor, int split_size = 128, int axis = 0) { - ShapeInfo shapeinfo = tensor.GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - axis = GetAxis(dims, axis); - Tensors new_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - new_shape.push_back(&Tensor::Constant(split_size)); - 
new_shape.push_back(&((*shape[i] + Tensor::Constant(split_size - 1)) / Tensor::Constant(split_size))); - } else { - new_shape.push_back(shape[i]); - } - } - Tensor& output = OpShape("dim_split", new_shape, &tensor); - output.SetData({(uint)axis, (uint)split_size}); - return output; - } - - static Tensor& MergeDim(const Tensor& tensor, int axis = 0, const Tensor* target_size = nullptr) { - ShapeInfo shapeinfo = tensor.GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - axis = GetAxis(dims, axis); - if(axis == 0) axis = 1; - const Tensor* target_size_tensor = nullptr; - if(target_size == nullptr) { - target_size_tensor = &(*shape[axis] * *shape[axis+1]); - } else { - target_size_tensor = target_size; - } - axis = GetAxis(dims, axis); - Tensors new_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if(i == axis) { - new_shape.push_back(target_size_tensor); - } else if(i != axis+1) { - new_shape.push_back(shape[i]); - } - } - Tensor& output = OpShape("dim_merge", new_shape, &tensor); - output.SetData(axis); - return output; - } - - static Tensor& Transpose(const Tensor& tensor, const int axis1 = 1, const int axis2 = 0) { - ShapeInfo shapeinfo = tensor.GetShapeInfo(); - - int dims = std::max(std::max(axis1+1, axis2+1), std::max(shapeinfo.dim, -std::min(axis1, axis2))); - int a1 = GetAxis(dims, axis1); - int a2 = GetAxis(dims, axis2); - shapeinfo.ExpandDimensionsTo(dims); - Tensors shape = shapeinfo.GetTensors(); - //swap the axes - std::swap(shape[a1], shape[a2]); - Tensor& output = OpShape("transpose", shape, &tensor); - //add data - output.SetData({AsUint(a1), AsUint(a2)}); - return output; - } - - //dot product of - static Tensor& Dot(const Tensor& tensor1, const Tensor& tensor2, int axis = 0) { - Tensors shape = tensor1.GetShape(); - int dims = (int)shape.size(); - axis = GetAxis(dims, axis); - shape.erase(shape.begin() + axis); - Tensor& output = OpShape("dot", shape, &tensor1, &tensor2); - output.SetData(axis); - return 
output; - } - - static Tensor& Unsqueeze(const Tensor& tensor, int axis = 0) { - Tensors shape = tensor.GetShape(); - int dims = (int)shape.size(); - if(axis < 0) { - axis = dims + axis + 1; - } - axis = std::max(0, std::min(dims, axis)); - shape.insert(shape.begin() + axis, &Constant(1)); - Tensor& output = OpShape("unsqueeze", shape, &tensor); - output.SetData(axis); - return output; - } - - static bool AreTensorsEqual(const Tensor& a, const Tensor& b); - - static Tensor& Squeeze(const Tensor& tensor, int axis = 0) { - Tensors shape = tensor.GetShape(); - int dims = (int)shape.size(); - axis = GetAxis(dims, axis); - if (shape[axis]->TryGetConstant() != 1) { - throw std::runtime_error("Cannot squeeze a dimension that is not 1"); - } - shape.erase(shape.begin() + axis); - Tensor& output = OpShape("squeeze", shape, &tensor); - output.SetData(axis); - return output; - } - - //takes two tensors [T1, T2, ..., Tn, M, N] and [Tm, .., Tn, N, K] and returns [T1, T2, ..., Tm, M, K] - static Tensor& Matmul(const Tensor& a, const Tensor& b) { - ShapeInfo shape_a = a.GetShapeInfo(); - ShapeInfo shape_b = b.GetShapeInfo(); - - if (shape_a.dim < 2 && shape_b.dim < 2) { - throw std::runtime_error("Matrix multiplication requires at least one 2D tensor"); - } - - if(shape_a.dim < 2) { - shape_a.ExpandDimensionsTo(2); - } - if(shape_b.dim < 2) { - shape_b.ExpandDimensionsTo(2); - } - - Tensors shape_a_tensors = shape_a.GetTensors(); - Tensors shape_b_tensors = shape_b.GetTensors(); - - //get shape of the result - Tensors shape_c = Tensors(); - int dim_a = shape_a.dim; - int dim_b = shape_b.dim; - int max_dim = 0; - Tensors max_shape = Tensors(); - //get the shape with most dimensions - if (dim_a < dim_b) { - max_dim = dim_b; - max_shape = shape_b_tensors; - } else { - max_dim = dim_a; - max_shape = shape_a_tensors; - } - - shape_c.push_back(shape_b_tensors[0]); - shape_c.push_back(shape_a_tensors[1]); - for (int i = 2; i < max_dim; i++) { - shape_c.push_back(max_shape[i]); - } - 
ShapeDimCompareResult result = CompareShapeDim(shape_a_tensors[0]->node_, shape_b_tensors[1]->node_); - if (!result.compatible) { - throw std::runtime_error("Inner dimensions of the matrices must match"); - } - - Tensor& output = OpShape("matmul", shape_c, &a, &b); - return output; - } - - static Tensor& Reshape(const Tensor &tensor, const Tensors &shape, TFDataFormat format = TFTypeNone); - static Tensor& Assert(const Tensor& tensor, const Tensors& shape, TFDataFormat type = TFTypeFloat32); - - Tensors enter_tensors = Tensors(); - bool already_entered = false; - - std::variant Enter() - { - if(!node_->op->HasAllTypes(OpProp::HasChildren)) { - throw std::runtime_error("The node of type " + node_->name + " cannot have children"); - } - - if(already_entered) { - throw std::runtime_error("Already entered node scope"); - } - evaluation_context_ir_->BeginScopeLastChild(node_); //begin at the last child - already_entered = true; - if(enter_tensors.size() > 0) { - return enter_tensors; //if we have some special info, like indices of kernel threads - } else { - return const_cast(this); - } - } - - void Exit() - { - evaluation_context_ir_->EndScope(); - } - - static Tensor& Loop(const Tensor& start, const Tensor& end, const Tensor& step) - { - return Op("loop", &start, &end, &step); - } - - static void Loop(const Tensor& start, const Tensor& end, const Tensor& step, - const function& body) { - // create the loop - Tensor& loop = Loop(start, end, step); - - evaluation_context_ir_->ExecuteExpressionFirstChild(loop.node_, [&]() { - // create the body - body(loop); - }); - } - - static Tensor& If(const Tensor& condition) { - // create the if - Tensor& if_tensor = Op("if", &condition); - return if_tensor; - } - - static void If(const Tensor& condition, - const std::function& body) { - // create the if - Tensor& if_tensor = If(condition); - - evaluation_context_ir_->ExecuteExpressionFirstChild(if_tensor.node_, [&]() { - // create the body - body(); - }); - } - - static void 
If(const Tensor& condition, const std::function& true_body, - const std::function& false_body) { - If(condition, true_body); - If(!condition, false_body); - } - - static void Vmap(const Tensors inputs, const Tensors shape, const std::function& body) { - // create the if - Tensor& vmap_main = OpShape("vmap", Tensors(), inputs); - - evaluation_context_ir_->ExecuteExpressionFirstChild(vmap_main.node_, [&]() { - // create the body - body(); - }); - } - - static Tensor& Kernel(const Tensors shape, vector group_size = {}) { - // create the kernel - Tensor& kernel = Static("kernel", shape, TFTypeNone); - evaluation_context_ir_->ExecuteExpressionFirstChild(kernel.node_, [&]() { - for (int i = 0; i < shape.size(); i++) { - kernel.enter_tensors.push_back(&Index(shape, i)); //thread indices - } - }); - kernel.node_->group_size = group_size; - return kernel; - } - - static Tensor& Kernel(const Tensors shape, const std::function& body, vector group_size = {}) { - // create the kernel - Tensor& kernel = Kernel(shape); - - evaluation_context_ir_->ExecuteExpressionLastChild(kernel.node_, [&]() { - // create the body - body(kernel.enter_tensors); - }); - - kernel.node_->group_size = group_size; - return kernel; - } - - static Tensor& Kernel(const NodeArguments& shape) - { - // create the kernel - Tensor& kernel = Static("kernel", shape, TFTypeNone); - return kernel; - } - - static void Break() { - // create the break - Tensor& break_tensor = Static("break", TFTypeNone); - } - - static void Continue() { - // create the continue - Tensor& continue_tensor = Static("continue", TFTypeNone); - } - - // destructor - ~Tensor() = default; - - Tensor& operator-() const { return Op("neg", this); } - Tensor& operator!() const { - if(node_->format.type == TFType::Bool) { - return Op("notb", this); - } else { - return Op("not", this); - } - } - Tensor& operator~() const { - if(node_->format.type == TFType::Bool) { - return Op("notb", this); - } else { - return Op("not", this); - } - } - - 
Tensor& operator+(const Tensor& other) const { - return Op("add", this, &other); - } - - Tensor& operator-(const Tensor& other) const { - return Op("sub", this, &other); - } - - Tensor& operator*(const Tensor& other) const { - return Op("mul", this, &other); - } - - Tensor& operator/(const Tensor& other) const { - return Op("div", this, &other); - } - - Tensor& operator%(const Tensor& other) const { - return Op("mod", this, &other); - } - - Tensor& operator>(const Tensor& other) const { - return Op("gt", this, &other); - } - - Tensor& operator<(const Tensor& other) const { - return Op("lt", this, &other); - } - - Tensor& operator>=(const Tensor& other) const { - return Op("gte", this, &other); - } - - Tensor& operator<=(const Tensor& other) const { - return Op("lte", this, &other); - } - - Tensor& operator==(const Tensor& other) const { - return Op("eq", this, &other); - } - - Tensor& operator!=(const Tensor& other) const { - return Op("neq", this, &other); - } - - Tensor& operator&&(const Tensor& other) const { - return Op("and", this, &other); - } - - Tensor& operator||(const Tensor& other) const { - return Op("or", this, &other); - } - - Tensor& operator&(const Tensor& other) const { - return Op("and", this, &other); - } - - Tensor& operator|(const Tensor& other) const { - return Op("or", this, &other); - } - - Tensor& operator^(const Tensor& other) const { - return Op("xor", this, &other); - } - - Tensor& operator<<(const Tensor& other) const { - return Op("lshift", this, &other); - } - - Tensor& operator>>(const Tensor& other) const { - return Op("rshift", this, &other); - } - - void operator=(const Tensor& other) = delete; - - static Tensor& copy(const Tensor& tensor) { - return Op("copy", &tensor); - } - - static Tensor& sin(const Tensor& x) { return Op("sin", &x); } - static Tensor& cos(const Tensor& x) { return Op("cos", &x); } - static Tensor& tan(const Tensor& x) { return Op("tan", &x); } - static Tensor& asin(const Tensor& x) { return Op("asin", &x); } 
- static Tensor& acos(const Tensor& x) { return Op("acos", &x); } - static Tensor& atan(const Tensor& x) { return Op("atan", &x); } - static Tensor& sinh(const Tensor& x) { return Op("sinh", &x); } - static Tensor& cosh(const Tensor& x) { return Op("cosh", &x); } - static Tensor& tanh(const Tensor& x) { return Op("tanh", &x); } - static Tensor& asinh(const Tensor& x) { return Op("asinh", &x); } - static Tensor& acosh(const Tensor& x) { return Op("acosh", &x); } - static Tensor& atanh(const Tensor& x) { return Op("atanh", &x); } - static Tensor& exp(const Tensor& x) { return Op("exp", &x); } - static Tensor& log(const Tensor& x) { return Op("log", &x); } - static Tensor& log2(const Tensor& x) { return Op("log2", &x); } - static Tensor& exp2(const Tensor& x) { return Op("exp2", &x); } - static Tensor& sqrt(const Tensor& x) { return Op("sqrt", &x); } - static Tensor& sqr(const Tensor& x) { return Op("sqr", &x); } - static Tensor& rsqrt(const Tensor& x) { return Op("rsqrt", &x); } - static Tensor& rcp(const Tensor& x) { return Op("rcp", &x); } - static Tensor& abs(const Tensor& x) { return Op("abs", &x); } - static Tensor& sign(const Tensor& x) { return Op("sign", &x); } - static Tensor& floor(const Tensor& x) { return Op("floor", &x); } - static Tensor& ceil(const Tensor& x) { return Op("ceil", &x); } - static Tensor& round(const Tensor& x) { return Op("round", &x); } - static Tensor& trunc(const Tensor& x) { return Op("trunc", &x); } - static Tensor& frac(const Tensor& x) { return Op("frac", &x); } - - static Tensor& pcg(const Tensor& x) { return Op("pcg", &x); } - static Tensor& pcgf(const Tensor& x) { return Op("pcgf", &x); } - - static Tensor& reversebits(const Tensor& x) { return Op("reversebits", &x); } - - static Tensor& tofloat(const Tensor& x) { return Op("float", &x); } - static Tensor& toint(const Tensor& x) { return Op("int", &x); } - static Tensor& touint(const Tensor& x) { return Op("uint", &x); } - static Tensor& tobool(const Tensor& x) { return 
Op("bool", &x); } - - static Tensor& asfloat(const Tensor& x) { return Op("asfloat", &x); } - static Tensor& asint(const Tensor& x) { return Op("asint", &x); } - static Tensor& asuint(const Tensor& x) { return Op("asuint", &x); } - - static Tensor& clamp(const Tensor& x, const Tensor& min, const Tensor& max) { - return Op("clamp", &x, &min, &max); - } - - static Tensor& pow(const Tensor& x, const Tensor& y) { - return Op("pow", &x, &y); - } - - static Tensor& min(const Tensor& x, const Tensor& y) { - return Op("min", &x, &y); - } - - static Tensor& max(const Tensor& x, const Tensor& y) { - return Op("max", &x, &y); - } - - static Tensor& mod(const Tensor& x, const Tensor& y) { - return Op("mod", &x, &y); - } - - static Tensor& modf(const Tensor& x, const Tensor& y) { - return Op("modf", &x, &y); - } - - static Tensor& atan2(const Tensor& x, const Tensor& y) { - return Op("atan2", &x, &y); - } - - static Tensor& grad(const Tensor& x, const Tensor& wrt) { - if(x.node_->op->HasAllTypes(OpProp::Nondiff) && !x.node_->flags.has(NodeProp::Modified)) { - throw std::runtime_error("Cannot compute gradient of a non-differentiable operation"); - } - return OpShape("backwards_grad", wrt.GetShape(), &x, &wrt); - } - - static Tensor& lerp(const Tensor& x, const Tensor& y, const Tensor& a) { - return Op("lerp", &x, &y, &a); - } - - static Tensor& smoothstep(const Tensor& a, const Tensor& b, const Tensor& x) { - return Op("smoothstep", &a, &b, &x); - } - - static Tensor& select(const Tensor& cond, const Tensor& x, const Tensor& y) { - return Op("ternary", &cond, &x, &y); - } - - static Tensor& fma(const Tensor& x, const Tensor& y, const Tensor& z) { - return Op("fma", &x, &y, &z); - } - - static Tensors IndexGrid(const Tensors& begin, const Tensors& end) { - //compute shape - Tensors shape = Tensors(); - for (int i = 0; i < begin.size(); i++) { - shape.push_back(&(*end[i] - *begin[i])); - } - //compute indices - Tensors index_grid = Tensors(); - for (int i = 0; i < begin.size(); 
i++) { - index_grid.push_back(&(Index(shape, i) + *begin[i])); - } - return index_grid; - } - - static Tensors IndexGrid(const Tensors& begin, const Tensors& end, const Tensors& step) - { - Tensors shape = Tensors(); - for (int i = 0; i < begin.size(); i++) { - shape.push_back(&((*end[i] - *begin[i]) / *step[i])); - } - //compute indices - Tensors index_grid = Tensors(); - for (int i = 0; i < begin.size(); i++) { - index_grid.push_back(&(Index(shape, i) * *step[i] + *begin[i])); - } - return index_grid; - } - - static void PrintValue(std::string name, const Tensor& tensor) { - if (tensor.GetDimension() != 0) { - throw std::runtime_error("Cannot print a non-scalar value"); - } - Tensor& output = Op("print_value", &tensor); - output.SetDebugName(name); - } - - static void AssertValue(std::string name, const Tensor& tensor) { - if (tensor.GetDimension() != 0) { - throw std::runtime_error("Cannot assert a non-scalar value"); - } - Tensor& output = Op("assert_value", &tensor); - output.SetDebugName(name); - } - - void SetDebugName(const string& name) const; - - static void BeginRegion(const string& name); - static void EndRegion(const string& name); - - int axis(int i = 0) const { - return (int)node_->data[i]; - } - - uint data(int i = 0) const { - return node_->data[i]; - } -}; - -} // namespace TensorFrost diff --git a/TensorFrost/Tensor/TensorProgram.cpp b/TensorFrost/Tensor/TensorProgram.cpp deleted file mode 100644 index 8d29de7e..00000000 --- a/TensorFrost/Tensor/TensorProgram.cpp +++ /dev/null @@ -1,89 +0,0 @@ -#include -#include -#include "TensorProgram.h" - -namespace TensorFrost { - -void TensorProgram::CreateProgram(string name) { - Tensor::SetEvaluationContext(nullptr); - - //get current time - auto start = std::chrono::high_resolution_clock::now(); - - // create new IR graph - Tensor::SetEvaluationContext(&ir); - Tensor::BeginRegion(name); - Tensors outputs = evaluate_callback(); - Tensor::EndRegion(name); - // set outputs - for (int i = 0; i < 
outputs.size(); i++) { - outputs[i]->SetMemoryType(NodeProp::OutputMemory, i); - } - ir.output_memory_count = (int)outputs.size(); - - if (outputs.size() == 0) { - throw std::runtime_error("TensorProgram does not do any computation: no outputs"); - } - - program = GenerateProgram(&ir); - program->program_name = name; - - Tensor::SetEvaluationContext(nullptr); - - auto end = std::chrono::high_resolution_clock::now(); - - compile_time = std::chrono::duration(end - start).count(); - - start = std::chrono::high_resolution_clock::now(); - - GenerateCode(program); - - end = std::chrono::high_resolution_clock::now(); - - codegen_time = std::chrono::duration(end - start).count(); - - if (current_backend != BackendType::CodeGen) // no need to compile if we are in codegen mode - { - auto start_time = chrono::high_resolution_clock::now(); - CompileAndLoadKernelModule(program, program_id); - auto end_time = chrono::high_resolution_clock::now(); - host_compile_time = chrono::duration(end_time - start_time).count(); - - start_time = chrono::high_resolution_clock::now(); - CompileKernels(program); - end_time = chrono::high_resolution_clock::now(); - shader_compile_time = chrono::duration(end_time - start_time).count(); - } -} - -vector TensorProgram::Evaluate( - const vector& input) const { - return ExecuteProgram(program, input); -} - -string TensorProgram::PrintProperties() const { - string properties = program->program_name + ":\n"; - int compute_kernels = (int)program->kernels_.size(); - int lines = 0; - string line; - istringstream stream(program->generated_code_); - while (getline(stream, line)) { - lines++; - } - properties += " Kernel count: " + to_string(compute_kernels) + "\n"; - properties += " Intermediate buffers: " + to_string(ir.temp_memory_count) + "\n"; - properties += " Host readbacks: " + to_string(ir.readbacks) + "\n"; - properties += " Host writes: " + to_string(ir.writebacks) + "\n"; - properties += " Lines of generated code: " + to_string(lines) + "\n"; - 
properties += " IR Compile time: " + to_string(compile_time) + " ms\n"; - properties += " Codegen time: " + to_string(codegen_time) + " ms\n"; - if(host_compile_time > 0.01f) - properties += " Host Compile time: " + to_string(host_compile_time) + " ms\n"; - if (shader_compile_time > 0.01f) - properties += " Shader Compile time: " + to_string(shader_compile_time) + " ms\n"; - return properties; -} - -size_t TensorProgram::program_id = 0; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Tensor/TensorProgram.h b/TensorFrost/Tensor/TensorProgram.h deleted file mode 100644 index 32abf0d1..00000000 --- a/TensorFrost/Tensor/TensorProgram.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "Backend/Backend.h" -#include "Compiler/KernelGen.h" -#include "Tensor.h" - -namespace TensorFrost { - -using namespace std; - -class TensorProgram { - public: - static size_t program_id; - using EvaluateFunction = function; - EvaluateFunction evaluate_callback; - IR ir; - Program* program; - bool debug = false; - float compile_time = 0.0f; - float codegen_time = 0.0f; - float host_compile_time = 0.0f; - float shader_compile_time = 0.0f; - - explicit TensorProgram(EvaluateFunction evaluate, string name) : evaluate_callback(std::move(evaluate)) { - CreateProgram(name); - program_id++; - } - - void CreateProgram(string name); - - vector Evaluate( - const vector& input) const; - - string PrintProperties() const; - - ~TensorProgram() { delete program; } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/TensorFrost.h b/TensorFrost/TensorFrost.h deleted file mode 100644 index 67d1accb..00000000 --- a/TensorFrost/TensorFrost.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -#ifdef _RELWITHDEBINFO -#define PYBIND11_DETAILED_ERROR_MESSAGES -#endif - -#include "Backend/Backend.h" -#include "Compiler/KernelGen.h" -#include "Tensor/Tensor.h" -#include 
"Tensor/TensorProgram.h" diff --git a/TensorFrost/Utility/Utility.cpp b/TensorFrost/Utility/Utility.cpp deleted file mode 100644 index 461ee402..00000000 --- a/TensorFrost/Utility/Utility.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Utility.h" - -namespace TensorFrost { - -int GetSize(const vector& shape) { - int size = 1; - for (int i : shape) { - size *= i; - } - return size; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Utility/Utility.h b/TensorFrost/Utility/Utility.h deleted file mode 100644 index 72d678d4..00000000 --- a/TensorFrost/Utility/Utility.h +++ /dev/null @@ -1,132 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace TensorFrost { -using namespace std; -using uint = unsigned int; - -inline uint AsUint(float f) { return *reinterpret_cast(&f); } -inline uint AsUint(int i) { return *reinterpret_cast(&i); } -inline float AsFloat(uint i) { return *reinterpret_cast(&i); } -inline float AsFloat(int i) { return *reinterpret_cast(&i); } -inline int AsInt(float f) { return *reinterpret_cast(&f); } -inline int AsInt(uint i) { return *reinterpret_cast(&i); } - -int GetSize(const vector& shape); - -template -class FlagSet { - array data; - -public: - FlagSet() { - data.fill(-1); - } - - void set(T flag) { - data[(int)flag] = 0; - } - - template - void set(T flag, Args... args) { - set(flag); - set(args...); - } - - void set(T flag, bool value) { - data[(int)flag] = value ? 0 : -1; - } - - void set(T flag, int64_t value) { - if(value < 0) { - throw std::runtime_error("Flag data must be non-negative"); - } - data[(int)flag] = value + 1; - } - - void remove(T flag) { - data[(int)flag] = -1; - } - - template - void remove(T flag, Args... args) { - remove(flag); - remove(args...); - } - - void clear() { - data.fill(-1); - } - - bool has(T flag) const { - return data[(int)flag] != -1; - } - - template - bool has(T flag, Args... 
args) const { - return has(flag) && has(args...); - } - - int64_t get(T flag, bool throw_error = true) const { - int64_t res = data[(int)flag] - 1; - if (throw_error && res < 0) { - throw std::runtime_error("Flag data is not set"); - } - return res; - } - - void copy_all_given(const FlagSet& other, unordered_set only) { - for (int i = 0; i < N; i++) { - data[i] = -1; - T flag = (T)i; - if (only.contains(flag)) { - data[(int)flag] = other.data[(int)flag]; - } - } - } - - void copy_all_except(const FlagSet& other, unordered_set except) { - for (int i = 0; i < N; i++) { - data[i] = -1; - T flag = (T)i; - if (!except.contains(flag)) { - data[(int)flag] = other.data[(int)flag]; - } - } - } - - void copy_all(const FlagSet& other) { - for (int i = 0; i < N; i++) { - data[i] = other.data[i]; - } - } - - size_t count() const { - size_t res = 0; - for (int i = 0; i < N; i++) { - if (data[i] != -1) { - res++; - } - } - return res; - } - - unordered_map get_data() const { - unordered_map res; - for (int i = 0; i < N; i++) { - T flag = (T)i; - if (has(flag)) { - res[flag] = data[i] - 1; - } - } - return res; - } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/IR/include/Common.h b/TensorFrost/include/Compiler/Common.h similarity index 89% rename from TensorFrost/IR/include/Common.h rename to TensorFrost/include/Compiler/Common.h index 5587025d..9b0d53d8 100644 --- a/TensorFrost/IR/include/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -44,11 +44,11 @@ extern "C" { template class auto_vector : public std::vector { public: - void set_element(size_t index, const T& value) { + void set_element(size_t index, T&& value) { if (index >= this->size()) { this->resize(index + 1); } - (*this)[index] = value; + (*this)[index] = std::forward(value); } }; @@ -107,3 +107,12 @@ using Attribute = std::variant; using AttributeMap = std::unordered_map; } + +namespace std { +template<> +struct hash { + size_t operator()(const 
TensorFrost::TFDataFormat& f) const noexcept { + return static_cast(f.GetHash()); + } +}; +} diff --git a/TensorFrost/IR/include/ExecutionContext.h b/TensorFrost/include/Compiler/ExecutionContext.h similarity index 100% rename from TensorFrost/IR/include/ExecutionContext.h rename to TensorFrost/include/Compiler/ExecutionContext.h diff --git a/TensorFrost/IR/include/Operation.h b/TensorFrost/include/Compiler/Operation.h similarity index 72% rename from TensorFrost/IR/include/Operation.h rename to TensorFrost/include/Compiler/Operation.h index 7d259908..b006ec3e 100644 --- a/TensorFrost/IR/include/Operation.h +++ b/TensorFrost/include/Compiler/Operation.h @@ -1,6 +1,11 @@ #pragma once #include "Common.h" +#include "OperationArguments.h" +#include "OperationBlocks.h" +#include "OperationRegistry.h" +#include "ExecutionContext.h" +#include "Overloads.h" namespace TensorFrost { diff --git a/TensorFrost/IR/include/OperationArguments.h b/TensorFrost/include/Compiler/OperationArguments.h similarity index 100% rename from TensorFrost/IR/include/OperationArguments.h rename to TensorFrost/include/Compiler/OperationArguments.h diff --git a/TensorFrost/IR/include/OperationBlocks.h b/TensorFrost/include/Compiler/OperationBlocks.h similarity index 100% rename from TensorFrost/IR/include/OperationBlocks.h rename to TensorFrost/include/Compiler/OperationBlocks.h diff --git a/TensorFrost/IR/include/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h similarity index 100% rename from TensorFrost/IR/include/OperationRegistry.h rename to TensorFrost/include/Compiler/OperationRegistry.h diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h new file mode 100644 index 00000000..bdb94c3f --- /dev/null +++ b/TensorFrost/include/Compiler/Overloads.h @@ -0,0 +1,149 @@ +#pragma once +#include "Operation.h" + +namespace TensorFrost { + +Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args, std::vector 
shape); + +template +Op& func_op(std::string op, const Args&... args) { + std::vector mem; + std::vector ids; + std::vector args_vec = {&args...}; + std::vector shape; + return make_op(op, mem, ids, args_vec, shape); +} + +Op& constant(int value); +Op& constant(uint value); +Op& constant(float value); +Op& constant(bool value); + +template concept Num = std::is_arithmetic_v>; +template concept IsOp = std::same_as, Op>; + +template +inline Op& as_op(T v) { + using D = std::remove_cvref_t; + using Target = + std::conditional_t, bool, + std::conditional_t, float, + std::conditional_t, unsigned int, + int>>>; + return constant(static_cast(v)); +} +inline Op& as_op(const Op& x) { return const_cast(x); } + +#define UNARY_OPERATOR(op, opname) \ +template \ +requires IsOp \ +inline Op& operator op(const T& x) { \ + return func_op(opname, as_op(x)); \ +} + +#define BINARY_OPERATOR(op, opname) \ +template \ +requires (IsOp || IsOp) \ +inline Op& operator op(const T& x, const U& y) { \ + return func_op(opname, as_op(x), as_op(y)); \ +} + +#define UNARY_FUNCTION(func, opname) \ +template \ +requires IsOp \ +inline Op& func(const T& x) { \ + return func_op(opname, as_op(x)); \ +} + +#define BINARY_FUNCTION(func, opname) \ +template \ +requires (IsOp || IsOp) \ +inline Op& func(const T& x, const U& y) { \ + return func_op(opname, as_op(x), as_op(y)); \ +} + +#define TERNARY_FUNCTION(func, opname) \ +template \ +requires (IsOp || IsOp || IsOp) \ +inline Op& func(const T& x, const U& y, const V& z) { \ + return func_op(opname, as_op(x), as_op(y), as_op(z)); \ +} + +UNARY_OPERATOR(+, "pos") +UNARY_OPERATOR(-, "neg") +UNARY_OPERATOR(~, "not") +UNARY_OPERATOR(!, "lnot") + +BINARY_OPERATOR(+, "add") +BINARY_OPERATOR(-, "sub") +BINARY_OPERATOR(*, "mul") +BINARY_OPERATOR(/, "div") +BINARY_OPERATOR(%, "mod") +BINARY_OPERATOR(&, "and") +BINARY_OPERATOR(|, "or") +BINARY_OPERATOR(^, "xor") +BINARY_OPERATOR(<<, "lshift") +BINARY_OPERATOR(>>, "rshift") +BINARY_OPERATOR(==, "eq") 
+BINARY_OPERATOR(!=, "neq")
+BINARY_OPERATOR(<, "lt")
+BINARY_OPERATOR(<=, "lte")
+BINARY_OPERATOR(>, "gt")
+BINARY_OPERATOR(>=, "gte")
+BINARY_OPERATOR(&&, "land")
+BINARY_OPERATOR(||, "lor")
+
+UNARY_FUNCTION(copy, "copy")
+UNARY_FUNCTION(sin, "sin")
+UNARY_FUNCTION(cos, "cos")
+UNARY_FUNCTION(tan, "tan")
+UNARY_FUNCTION(asin, "asin")
+UNARY_FUNCTION(acos, "acos")
+UNARY_FUNCTION(atan, "atan")
+UNARY_FUNCTION(sinh, "sinh")
+UNARY_FUNCTION(cosh, "cosh")
+UNARY_FUNCTION(tanh, "tanh")
+UNARY_FUNCTION(asinh, "asinh")
+UNARY_FUNCTION(acosh, "acosh")
+UNARY_FUNCTION(atanh, "atanh")
+UNARY_FUNCTION(exp, "exp")
+UNARY_FUNCTION(log, "log")
+UNARY_FUNCTION(log2, "log2")
+UNARY_FUNCTION(exp2, "exp2")
+UNARY_FUNCTION(sqrt, "sqrt")
+UNARY_FUNCTION(sqr, "sqr")
+UNARY_FUNCTION(rsqrt, "rsqrt")
+UNARY_FUNCTION(rcp, "rcp")
+UNARY_FUNCTION(abs, "abs")
+UNARY_FUNCTION(sign, "sign")
+UNARY_FUNCTION(floor, "floor")
+UNARY_FUNCTION(ceil, "ceil")
+UNARY_FUNCTION(round, "round")
+UNARY_FUNCTION(trunc, "trunc")
+UNARY_FUNCTION(frac, "frac")
+UNARY_FUNCTION(pcg, "pcg")
+UNARY_FUNCTION(pcgf, "pcgf")
+UNARY_FUNCTION(reversebits, "reversebits")
+UNARY_FUNCTION(tofloat, "tofloat")
+UNARY_FUNCTION(toint, "toint")
+UNARY_FUNCTION(touint, "touint")
+UNARY_FUNCTION(tobool, "tobool")
+UNARY_FUNCTION(asfloat, "asfloat")
+UNARY_FUNCTION(asint, "asint")
+UNARY_FUNCTION(asuint, "asuint")
+
+BINARY_FUNCTION(pow, "pow")
+BINARY_FUNCTION(min, "min")
+BINARY_FUNCTION(max, "max")
+BINARY_FUNCTION(mod, "mod")
+BINARY_FUNCTION(modf, "modf")
+BINARY_FUNCTION(atan2, "atan2")
+BINARY_FUNCTION(grad, "backwards_grad")
+
+TERNARY_FUNCTION(clamp, "clamp") // clamp(x, min, max) takes three operands, so it belongs with the ternary overloads
+TERNARY_FUNCTION(lerp, "lerp")
+TERNARY_FUNCTION(smoothstep, "smoothstep")
+TERNARY_FUNCTION(select, "ternary")
+TERNARY_FUNCTION(fma, "fma")
+
+}
diff --git a/TensorFrost/IR/src/Common.cpp b/TensorFrost/src/Compiler/Common.cpp
similarity index 89%
rename from TensorFrost/IR/src/Common.cpp
rename to TensorFrost/src/Compiler/Common.cpp
index b0fc8c64..b78a0a84 100644
--- a/TensorFrost/IR/src/Common.cpp +++ b/TensorFrost/src/Compiler/Common.cpp @@ -1,7 +1,6 @@ -#include "../include/Common.h" - -using namespace TensorFrost; +#include "Compiler/Common.h" +namespace TensorFrost { bool TFDataFormat::operator==(const TFDataFormat &other) const { return type == other.type && size == other.size; } @@ -21,4 +20,4 @@ bool TFDataFormat::operator<(const TFDataFormat &other) const { bool TFDataFormat::operator>(const TFDataFormat &other) const { return GetHash() > other.GetHash(); } - +} diff --git a/TensorFrost/IR/src/ExecutionContext.cpp b/TensorFrost/src/Compiler/ExecutionContext.cpp similarity index 88% rename from TensorFrost/IR/src/ExecutionContext.cpp rename to TensorFrost/src/Compiler/ExecutionContext.cpp index bc941210..1e1c6a9f 100644 --- a/TensorFrost/IR/src/ExecutionContext.cpp +++ b/TensorFrost/src/Compiler/ExecutionContext.cpp @@ -1,8 +1,8 @@ -#include "../include/ExecutionContext.h" -#include "../include/Operation.h" - -using namespace TensorFrost; +#include "Compiler/ExecutionContext.h" +#include "Compiler/Operation.h" +#include "Compiler/OperationBlocks.h" +namespace TensorFrost { ExecutionContext::ExecutionContext(): base_block(std::make_unique()), current_block(base_block.get()) {} void ExecutionContext::BeginBlock(Op *op) { @@ -44,4 +44,4 @@ void EndExecutionContext() { delete current_context; current_context = nullptr; } - +} \ No newline at end of file diff --git a/TensorFrost/IR/src/Operation.cpp b/TensorFrost/src/Compiler/Operation.cpp similarity index 79% rename from TensorFrost/IR/src/Operation.cpp rename to TensorFrost/src/Compiler/Operation.cpp index 8ef73e6f..96e96914 100644 --- a/TensorFrost/IR/src/Operation.cpp +++ b/TensorFrost/src/Compiler/Operation.cpp @@ -1,9 +1,6 @@ -#include "../include/Operation.h" -#include "../include/OperationArguments.h" -#include "../include/Overloads.h" - -using namespace TensorFrost; +#include "Compiler/Operation.h" +namespace TensorFrost { Op::Op(std::string op_name): 
opcode(std::move(op_name)) { args = std::make_unique(this); type = TFTypeNone; @@ -28,3 +25,4 @@ Op::Op(bool value) : Op(std::string("const")) { attributes["value"] = value; type = TFTypeBool32; } +} \ No newline at end of file diff --git a/TensorFrost/IR/src/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp similarity index 92% rename from TensorFrost/IR/src/OperationArguments.cpp rename to TensorFrost/src/Compiler/OperationArguments.cpp index 58e75438..1edc31c6 100644 --- a/TensorFrost/IR/src/OperationArguments.cpp +++ b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -1,8 +1,6 @@ -#include "../include/OperationArguments.h" -#include "../include/Operation.h" - -using namespace TensorFrost; +#include "Compiler/Operation.h" +namespace TensorFrost { void Arguments::AddInput(ArgType type, Op *from, int index) { inputs.set_element(index, std::make_unique(Argument{type, from, parent_op, index})); from->args->SetAsOutput(inputs[index].get()); @@ -42,3 +40,4 @@ void ArgumentManager::SetArguments(ArgType type, std::vector args) { AddArgument(args[i], type, (int)i); } } +} \ No newline at end of file diff --git a/TensorFrost/IR/src/OperationBlocks.cpp b/TensorFrost/src/Compiler/OperationBlocks.cpp similarity index 97% rename from TensorFrost/IR/src/OperationBlocks.cpp rename to TensorFrost/src/Compiler/OperationBlocks.cpp index 27ef0e3c..6af9901f 100644 --- a/TensorFrost/IR/src/OperationBlocks.cpp +++ b/TensorFrost/src/Compiler/OperationBlocks.cpp @@ -1,7 +1,6 @@ -#include "../include/OperationBlocks.h" - -using namespace TensorFrost; +#include "Compiler/Operation.h" +namespace TensorFrost { Op* OpBlock::append(std::unique_ptr op) { ops.emplace_back(std::move(op)); return ops.back().get(); @@ -80,4 +79,5 @@ bool OpBlockIterator::up() { stack.pop_back(); current_op = stack.back().it->get(); return true; +} } \ No newline at end of file diff --git a/TensorFrost/IR/src/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp similarity 
index 78% rename from TensorFrost/IR/src/OperationRegistry.cpp rename to TensorFrost/src/Compiler/OperationRegistry.cpp index 50b9df73..a33ca7ae 100644 --- a/TensorFrost/IR/src/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -1,9 +1,8 @@ -#include "../include/Operation.h" -#include "../include/OperationRegistry.h" +#include "Compiler/Operation.h" -using namespace TensorFrost; using namespace std; +namespace TensorFrost { OpSpec::OpSpec(std::string op_name, OverloadsMap overloads_list) { name = std::move(op_name); overloads = std::move(overloads_list); @@ -18,12 +17,12 @@ TFDataFormat OpSpec::GetOutputType(const std::vector &args) const } static const std::unordered_map tok = { - {"f", TFDataFormat::TFTypeFloat32}, - {"i", TFDataFormat::TFTypeInt32}, - {"u", TFDataFormat::TFTypeUint32}, - {"tuple", TFDataFormat::TFTypeTuple}, - {"b", TFDataFormat::TFTypeBool32}, - {"void", TFDataFormat::TFTypeNone}, + {"f", TFTypeFloat32}, + {"i", TFTypeInt32}, + {"u", TFTypeUint32}, + {"tuple", TFTypeTuple}, + {"b", TFTypeBool32}, + {"void", TFTypeNone}, }; static std::string trim(std::string_view s) { @@ -65,26 +64,27 @@ vector default_operations = { OpSpec("parallel", ovr("tuple()")), }; -std::unordered_map CreateOperationRegistry() { - std::unordered_map registry; +std::unordered_map> CreateOperationRegistry() { + std::unordered_map> registry; for (const auto& op : default_operations) { - registry[op.name] = op; + registry[op.name] = std::make_unique(op); } return registry; } -std::unordered_map operation_registry = CreateOperationRegistry(); +std::unordered_map> operation_registry = CreateOperationRegistry(); void TensorFrost::RegisterOperation(const OpSpec &spec) { if (operation_registry.contains(spec.name)) { throw std::runtime_error("Operation already registered: " + spec.name); } - operation_registry[spec.name] = spec; + operation_registry[spec.name] = std::make_unique(spec); } OpSpec* TensorFrost::GetOpSpec(const std::string &name) { if 
(!operation_registry.contains(name)) { throw std::runtime_error("Operation not found: " + name); } - return &operation_registry[name]; + return operation_registry[name].get(); } +} \ No newline at end of file diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp new file mode 100644 index 00000000..cfdc4b05 --- /dev/null +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -0,0 +1,51 @@ +#include "Compiler/Operation.h" +#include "Compiler/ExecutionContext.h" + +using namespace std; + +namespace TensorFrost { +// General function to create an Op instance in the current execution context +Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args, std::vector shape) { + OpSpec* spec = GetOpSpec(op); + vector arg_types; + for (const auto& arg : args) { + arg_types.push_back(arg->type); + } + TFDataFormat output_type = spec->GetOutputType(arg_types); + Op* op_instance = new Op(op); + op_instance->type = output_type; + op_instance->args->SetArguments(ArgType::Memory, mem); + op_instance->args->SetArguments(ArgType::Index, ids); + op_instance->args->SetArguments(ArgType::Input, args); + op_instance->args->SetArguments(ArgType::Shape, shape); + return GetContext()->AddOp(std::unique_ptr(op_instance)); +} + +Op& constant(int value) { + Op& const_op = func_op("const"); + const_op.attributes["value"] = value; + const_op.type = TFTypeInt32; + return const_op; +} + +Op& constant(uint value) { + Op& const_op = func_op("const"); + const_op.attributes["value"] = value; + const_op.type = TFTypeUint32; + return const_op; +} + +Op& constant(float value) { + Op& const_op = func_op("const"); + const_op.attributes["value"] = value; + const_op.type = TFTypeFloat32; + return const_op; +} + +Op& constant(bool value) { + Op& const_op = func_op("const"); + const_op.attributes["value"] = value; + const_op.type = TFTypeBool32; + return const_op; +} +} \ No newline at end of file diff --git a/TensorFrost/src/Definitions/PyModule.cpp 
b/TensorFrost/src/Definitions/PyModule.cpp new file mode 100644 index 00000000..d5100867 --- /dev/null +++ b/TensorFrost/src/Definitions/PyModule.cpp @@ -0,0 +1,62 @@ +// #include +// #include +// +// #include +// #include +// +// namespace TensorFrost { +// +// class PyModule : public Module { +// public: +// using Module::Module; // Inherit constructors +// +// void assert_parameters() override { +// PYBIND11_OVERRIDE(void, Module, assert_parameters); +// } +// +// py::object loss(py::object X, py::object Y) override { +// PYBIND11_OVERRIDE_PURE(py::object, Module, loss, X, Y); +// } +// +// py::object forward(py::object X) override { +// PYBIND11_OVERRIDE_PURE(py::object, Module, forward, X); +// } +// }; +// +// void ModuleDefinitions(py::module& m) { +// py::class_(m, "Parameter") +// .def(py::init&, TFDataFormat, float, float, bool>(), py::arg("shape"), py::arg("dtype") = TFType::Float, py::arg("random_scale") = -1.0f, py::arg("random_offset") = 0.0f, py::arg("optimize") = true) +// .def_readwrite("shape", &Parameter::shape) +// .def_readwrite("dtype", &Parameter::dtype) +// .def_readwrite("random_scale", &Parameter::random_scale) +// .def_readwrite("random_offset", &Parameter::random_offset) +// .def("__repr__", [](const Parameter& p) { +// return "Parameter(shape=" + std::to_string(p.shape.size()) + ", dtype=" + std::to_string(p.dtype.type) + "( " + std::to_string(p.dtype.size) + ") , random_scale=" + std::to_string(p.random_scale) + ", random_offset=" + std::to_string(p.random_offset) + ", optimize=" + std::to_string(p.optimize) + ")"; +// }); +// +// py::class_(m, "ParameterArray") +// .def(py::init<>()) +// .def("__getitem__", &ParameterArray::getitem) +// .def("__setitem__", &ParameterArray::setitem) +// .def("items", &ParameterArray::items); +// +// py::class_(m, "Module") +// .def(py::init(), py::arg("requires_grad") = true) +// .def("__getattr__", &Module::getattr) +// .def("__setattr__", &Module::setattr) +// .def("hasattr", &Module::hasattr) +// 
.def("param_requires_grad", &Module::param_requires_grad) +// .def("initialize_input", &Module::initialize_input) +// .def("initialize_parameters", &Module::initialize_parameters) +// .def("initialize_parameters_native", &Module::initialize_parameters_native) +// .def("parameters", &Module::parameters) +// .def("named_parameters", &Module::named_parameters) +// .def("requires_grads_list", &Module::requires_grads_list) +// .def("create_input", &Module::create_input) +// .def("update_parameters", &Module::update_parameters) +// .def("assert_parameters", &Module::assert_parameters) +// .def("loss", &Module::loss) +// .def("forward", &Module::forward); +// } +// +// } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/src/Definitions/PyTensor.cpp b/TensorFrost/src/Definitions/PyTensor.cpp new file mode 100644 index 00000000..624e54e4 --- /dev/null +++ b/TensorFrost/src/Definitions/PyTensor.cpp @@ -0,0 +1,199 @@ +// #include +// #include +// +// #include +// +// namespace TensorFrost { +// +// void DefineOperator( +// const std::string& pyname, py::class_& py_tensor, +// const std::function& op) { +// py_tensor.def(l_op(pyname).c_str(), +// [op](const PyTensor& t, const PyTensor& t2) { +// return PT(op(T(t), T(t2))); +// }); +// py_tensor.def(l_op(pyname).c_str(), [op](const PyTensor& t, const float f) { +// return PT(op(T(t), Tensor::Constant(f))); +// }); +// py_tensor.def(l_op(pyname).c_str(), [op](const PyTensor& t, const int i) { +// return PT(op(T(t), Tensor::Constant(i))); +// }); +// py_tensor.def(r_op(pyname).c_str(), [op](const PyTensor& t, const float f) { +// return PT(op(Tensor::Constant(f), T(t))); +// }); +// py_tensor.def(r_op(pyname).c_str(), [op](const PyTensor& t, const int i) { +// return PT(op(Tensor::Constant(i), T(t))); +// }); +// } +// +// #define LAMBDA_OP(op) \ +// [](const Tensor& t1, const Tensor& t2) -> Tensor& { return t1 op t2; } +// +// void DefineOperators(py::class_& py_tensor) { +// DefineOperator("add", 
py_tensor, LAMBDA_OP(+)); +// DefineOperator("sub", py_tensor, LAMBDA_OP(-)); +// DefineOperator("mul", py_tensor, LAMBDA_OP(*)); +// DefineOperator("div", py_tensor, LAMBDA_OP(/)); +// DefineOperator("truediv", py_tensor, LAMBDA_OP(/)); +// DefineOperator("mod", py_tensor, LAMBDA_OP(%)); +// DefineOperator("eq", py_tensor, LAMBDA_OP(==)); +// DefineOperator("ne", py_tensor, LAMBDA_OP(!=)); +// DefineOperator("lt", py_tensor, LAMBDA_OP(<)); +// DefineOperator("le", py_tensor, LAMBDA_OP(<=)); +// DefineOperator("gt", py_tensor, LAMBDA_OP(>)); +// DefineOperator("ge", py_tensor, LAMBDA_OP(>=)); +// DefineOperator("and", py_tensor, LAMBDA_OP(&&)); +// DefineOperator("or", py_tensor, LAMBDA_OP(||)); +// DefineOperator("xor", py_tensor, LAMBDA_OP(^)); +// DefineOperator("lshift", py_tensor, LAMBDA_OP(<<)); +// DefineOperator("rshift", py_tensor, LAMBDA_OP(>>)); +// DefineOperator("and_", py_tensor, LAMBDA_OP(&)); +// DefineOperator("or_", py_tensor, LAMBDA_OP(|)); +// +// py_tensor.def("__neg__", [](const PyTensor& t) { return PT(-T(t)); }); +// py_tensor.def("__not__", [](const PyTensor& t) { return PT(!T(t)); }); +// py_tensor.def("__invert__", [](const PyTensor& t) { return PT(~T(t)); }); +// py_tensor.def("__pow__", [](const PyTensor& t, const PyTensor& t2) { +// return PT(Tensor::pow(T(t), T(t2))); +// }); +// py_tensor.def("__pow__", [](const PyTensor& t, float f) { +// return PT(Tensor::pow(T(t), Tensor::Constant(f))); +// }); +// py_tensor.def("__rpow__", [](const PyTensor& t, float f) { +// return PT(Tensor::pow(Tensor::Constant(f), T(t))); +// }); +// py_tensor.def("__matmul__", [](const PyTensor& t, const PyTensor& t2) { +// return PT(Tensor::Matmul(T(t), T(t2))); +// }); +// } +// +// void PyTensorDefinition(py::module& /*m*/, py::class_& py_tensor) { +// // initializers +// py_tensor.def(py::init()); +// py_tensor.def(py::init()); +// py_tensor.def(py::init()); +// py_tensor.def(py::init()); +// +// // properties +// py_tensor.def_property_readonly("shape", 
[](const PyTensor& t) { +// return PyTensorsFromTensors(Reverse(t.Get().GetShape())); +// }); +// py_tensor.def_property_readonly( +// "type", [](const PyTensor& t) { return t.Get().GetFormat(); }); +// py_tensor.def_property_readonly("indices", [](const PyTensor& t) { +// int dim = T(t).GetDimension(); +// py::tuple indices(dim); +// for (int i = 0; i < dim; i++) { +// indices[i] = PT(T(t).Index(dim - i - 1)); +// } +// return indices; +// }); +// py_tensor.def_property_readonly("op_name", [](const PyTensor& t) { +// return T(t).node_->name; +// }); +// py_tensor.def("try_get_constant", [](const PyTensor& t) { +// if(T(t).node_->name != "const") { +// throw std::runtime_error("Can not get constant from non-constant tensor"); +// } +// return T(t).TryGetConstant(); +// }); +// py_tensor.def("index",[](const PyTensor& t, int dim) { +// int dims = T(t).GetDimension(); +// return PT(T(t).Index(dims - dim - 1)); +// }); +// +// py_tensor.def("block_index", [](const PyTensor& t) { +// return PT(T(t).BlockIndex()); +// }); +// +// py_tensor.def("block_thread_index", [](const PyTensor& t, int block_dim) { +// return PT(T(t).BlockThreadIndex(block_dim)); +// }); +// +// py_tensor.def("detach_grad", [](const PyTensor& t) { +// t.Get().DetachGrad(); +// return t; +// }); +// +// py_tensor.def("pass_grad", [](const PyTensor& t) { +// t.Get().PassGrad(); +// return t; +// }); +// +// py_tensor.def("stop_fusion", [](const PyTensor& t) { +// t.Get().StopFusion(); +// return t; +// }); +// +// py_tensor.def("hint_range", [](const PyTensor& t, py::object min, py::object max) { +// if(t.Get().node_->format == TFTypeFloat32) { +// t.Get().HintRange(py::cast(min), py::cast(max)); +// } else { +// t.Get().HintRange(py::cast(min), py::cast(max)); +// } +// }, py::arg("min"), py::arg("max")); +// +// // operators +// DefineOperators(py_tensor); +// +// //no way to overload normal setter +// //TODO use python AST to generate these functions +// py_tensor.def("set", +// [](const PyTensor& 
t, const PyTensor& t2) { T(t).Set(T(t2)); }); +// +// py_tensor.def_property("val", [](const PyTensor& t) { return t; }, +// [](PyTensor& t, const PyTensor& val) { T(t).Set(T(val)); }); +// +// // indexing +// py_tensor.def("__getitem__", [](const PyTensor& t, const PyTensor& t1) { +// Tensors indices; +// indices.push_back(&t1.Get()); +// return PyTensor(&t.Get(), indices); +// }); +// py_tensor.def("__getitem__", [](const PyTensor& t, py::tuple indices_tuple) { +// Tensors indices = Reverse(TensorsFromTuple(indices_tuple)); +// return PyTensor(&t.Get(), indices); +// }); +// +// py_tensor.def("__setitem__", +// [](const PyTensor& t, const PyTensor& t1, const PyTensor& t2) { +// Tensors indices; +// indices.push_back(&t1.Get()); +// Tensor::Store(t.Get(), T(t2), indices); +// }); +// py_tensor.def("__setitem__", [](const PyTensor& t, py::tuple indices_tuple, +// const PyTensor& t2) { +// Tensors indices = Reverse(TensorsFromTuple(indices_tuple)); +// Tensor::Store(t.Get(), T(t2), indices); +// }); +// +// py_tensor.def("__setitem__", [](const PyTensor& t, const PyTensor& t1, pybind11::none none) { +// //do nothing +// }); +// py_tensor.def("__setitem__", [](const PyTensor& t, py::tuple indices_tuple, pybind11::none none) { +// //do nothing +// }); +// +// // transpose +// py_tensor.def("transpose", [](const PyTensor& t, int dim1, int dim2) { +// return PT(Tensor::Transpose(T(t), -dim1-1, -dim2-1)); +// }, py::arg("dim1") = -2, py::arg("dim2") = -1, "Transpose the tensor"); +// +// //transpose property +// py_tensor.def_property_readonly("T", [](const PyTensor& t) { +// return PT(Tensor::Transpose(T(t))); +// }); +// +// py_tensor.def("__str__", [](const PyTensor& t) { +// return GetNodeString(t.Get().node_); +// }); +// py_tensor.def("__repr__", [](const PyTensor& t) { +// return GetNodeString(t.Get().node_); +// }); +// +// py_tensor.def("set_debug_name", [](const PyTensor& t, const std::string& name) { +// t.Get().SetDebugName(name); +// }); +// } +// +// } // 
namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/src/Definitions/TensorFunctions.cpp b/TensorFrost/src/Definitions/TensorFunctions.cpp new file mode 100644 index 00000000..6f559ed1 --- /dev/null +++ b/TensorFrost/src/Definitions/TensorFunctions.cpp @@ -0,0 +1,353 @@ +// #include +// #include +// +// #include +// +// namespace TensorFrost { +// +// #define UNARY_FUNCTION(name) \ +// m.def(#name, [](const PyTensor& t) { return PT(Tensor::name(T(t))); }) +// +// #define BINARY_FUNCTION(name) \ +// m.def(#name, [](const PyTensor& t, const PyTensor& t2) { \ +// return PT(Tensor::name(T(t), T(t2))); \ +// }) +// +// #define TERNARY_FUNCTION(name) \ +// m.def(#name, [](const PyTensor& t, const PyTensor& t2, const PyTensor& t3) { \ +// return PT(Tensor::name(T(t), T(t2), T(t3))); \ +// }) +// +// void TensorFunctionsDefinition(py::module& m) { +// UNARY_FUNCTION(copy); +// UNARY_FUNCTION(abs); +// UNARY_FUNCTION(ceil); +// UNARY_FUNCTION(floor); +// UNARY_FUNCTION(round); +// UNARY_FUNCTION(trunc); +// UNARY_FUNCTION(sign); +// UNARY_FUNCTION(frac); +// UNARY_FUNCTION(sin); +// UNARY_FUNCTION(cos); +// UNARY_FUNCTION(tan); +// UNARY_FUNCTION(asin); +// UNARY_FUNCTION(acos); +// UNARY_FUNCTION(atan); +// UNARY_FUNCTION(sinh); +// UNARY_FUNCTION(cosh); +// UNARY_FUNCTION(tanh); +// UNARY_FUNCTION(exp); +// UNARY_FUNCTION(exp2); +// UNARY_FUNCTION(log); +// UNARY_FUNCTION(log2); +// UNARY_FUNCTION(sqrt); +// UNARY_FUNCTION(sqr); +// UNARY_FUNCTION(rsqrt); +// UNARY_FUNCTION(rcp); +// +// UNARY_FUNCTION(pcg); +// UNARY_FUNCTION(pcgf); +// UNARY_FUNCTION(reversebits); +// +// m.def("float", [](const PyTensor& t) { return PT(Tensor::tofloat(T(t))); }); +// m.def("uint", [](const PyTensor& t) { return PT(Tensor::touint(T(t))); }); +// m.def("int", [](const PyTensor& t) { return PT(Tensor::toint(T(t))); }); +// m.def("bool", [](const PyTensor& t) { return PT(Tensor::tobool(T(t))); }); +// +// m.def("asfloat", [](const PyTensor& t) { return 
PT(Tensor::asfloat(T(t))); }); +// m.def("asuint", [](const PyTensor& t) { return PT(Tensor::asuint(T(t))); }); +// m.def("asint", [](const PyTensor& t) { return PT(Tensor::asint(T(t))); }); +// +// BINARY_FUNCTION(min); +// BINARY_FUNCTION(max); +// BINARY_FUNCTION(pow); +// BINARY_FUNCTION(atan2); +// BINARY_FUNCTION(modf); +// +// BINARY_FUNCTION(grad); +// +// TERNARY_FUNCTION(clamp); +// TERNARY_FUNCTION(fma); +// TERNARY_FUNCTION(lerp); +// TERNARY_FUNCTION(select); +// TERNARY_FUNCTION(smoothstep); +// +// m.def("scatterAdd", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterAdd(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterAddPrev", [](const PyTensor& t, const PyTensor& t2) { +// return PT(Tensor::ScatterAddPrev(*t.Value(), T(t2), t.Indices())); +// }); +// +// m.def("scatterMin", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterMin(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterMax", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterMax(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterOr", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterOr(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterAnd", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterAnd(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterXor", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterXor(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("buffer", [](py::list shape, TFDataFormat type) { +// return PT(Tensor::Memory(Reverse(TensorsFromList(shape)), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// m.def("buffer", [](std::vector shape, TFDataFormat type) { +// return PT(Tensor::Memory(Reverse(shape), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// +// m.def("local_buffer", [](int size, TFDataFormat type) { +// return PT(Tensor::LocalMemory(size, type)); +// }, py::arg("size"), py::arg("type") 
= TFTypeFloat32); +// m.def("group_buffer", [](int size, TFDataFormat type) { +// return PT(Tensor::GroupMemory(size, type)); +// }, py::arg("size"), py::arg("type") = TFTypeFloat32); +// m.def("group_barrier", []() { +// Tensor::GroupBarrier(); +// }); +// +// m.def("zeros", [](py::list shape, TFDataFormat type) { +// return PT(Tensor::Constant(0u, Reverse(TensorsFromList(shape)), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// +// m.def("const", [](float value, py::list shape) { +// return PT(Tensor::Constant(Reverse(TensorsFromList(shape)), value)); +// }); +// m.def("const", [](float value, std::vector shape) { +// return PT(Tensor::Constant(Reverse(shape), value)); +// }, py::arg("value"), py::arg("shape") = std::vector{}); +// +// m.def("const", [](int value, py::list shape) { +// return PT(Tensor::Constant(Reverse(TensorsFromList(shape)), value)); +// }); +// +// m.def("const", [](int value, std::vector shape) { +// return PT(Tensor::Constant(Reverse(shape), value)); +// }, py::arg("value"), py::arg("shape") = std::vector{}); +// +// m.def("input", [](std::vector shape, TFDataFormat type) { +// return PT(Tensor::Input(Reverse(shape), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// +// m.def("input", [](py::list shape, TFDataFormat type) { +// return PT(Tensor::Input(Reverse(TensorsFromList(shape)), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// +// m.def("index", [](int dim, py::list shape) { +// return PT(Tensor::Index(Reverse(TensorsFromList(shape)), dim)); +// }); +// +// m.def("hash", [](py::list shape, const PyTensor& seed) { +// return PT(Tensor::Hash(Reverse(TensorsFromList(shape)), T(seed))); +// }, py::arg("shape"), py::arg("seed")); +// +// m.def("random_value", [](py::list shape, const PyTensor& seed) { +// return PT(Tensor::Random(Reverse(TensorsFromList(shape)), T(seed))); +// }, py::arg("shape"), py::arg("seed")); +// +// m.def("element_index", [](py::list shape) { +// return 
PT(Tensor::ElementIndex(Reverse(TensorsFromList(shape)))); +// }, py::arg("shape")); +// +// m.def("flat_index", [](py::list shape, py::list indices) { +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// Tensors index_tensors = Reverse(TensorsFromList(indices)); +// return PT(Tensor::FlatIndex(shape_tensors, index_tensors)); +// }); +// +// m.def("indices_from_flat_index", [](const PyTensor& index, py::list shape) { +// py::tuple indices = py::tuple(shape.size()); +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// Tensors indices_tensors = Reverse(Tensor::IndicesFromFlatIndex(&T(index), shape_tensors)); +// for (int i = 0; i < indices_tensors.size(); i++) { +// indices[i] = PT(*indices_tensors[i]); +// } +// return indices; +// }); +// +// m.def("get_copy", [](const PyTensor& t) { return PT(*Tensor::GetCopy(T(t))); }); +// +// m.def("indices", [](py::list shape) { +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// int dim = (int)shape_tensors.size(); +// py::tuple indices = py::tuple(shape_tensors.size()); +// for (int i = 0; i < shape_tensors.size(); i++) { +// auto t = PT(Tensor::Index(shape_tensors, dim - i - 1)); +// indices[i] = t; +// } +// return indices; +// }); +// +// m.def("indices", [](std::vector shape) { +// py::tuple indices = py::tuple(shape.size()); +// int dim = (int)shape.size(); +// for (int i = 0; i < shape.size(); i++) { +// auto t = PT(Tensor::Index(Reverse(shape), dim - i - 1)); +// indices[i] = t; +// } +// return indices; +// }); +// +// +// m.def("index_grid", [](py::list begin, py::list end) { +// Tensors begin_tensors = Reverse(TensorsFromList(begin)); +// Tensors end_tensors = Reverse(TensorsFromList(end)); +// Tensors index_grid = Reverse(Tensor::IndexGrid(begin_tensors, end_tensors)); +// +// py::tuple indices = py::tuple(begin.size()); +// for (int i = 0; i < index_grid.size(); i++) { +// indices[i] = PT(*index_grid[i]); +// } +// return indices; +// }); +// +// m.def("index_grid", 
[](py::list begin, py::list end, py::list step) { +// Tensors begin_tensors = Reverse(TensorsFromList(begin)); +// Tensors end_tensors = Reverse(TensorsFromList(end)); +// Tensors step_tensors = Reverse(TensorsFromList(step)); +// Tensors index_grid = Reverse(Tensor::IndexGrid(begin_tensors, end_tensors, step_tensors)); +// +// py::tuple indices = py::tuple(begin.size()); +// for (int i = 0; i < index_grid.size(); i++) { +// indices[i] = PT(*index_grid[i]); +// } +// return indices; +// }); +// +// m.def("reshape", [](const PyTensor& t, py::list shape, TFDataFormat type) { +// return PT(Tensor::Reshape(T(t), Reverse(TensorsFromList(shape)), type)); +// }, py::arg("t"), py::arg("shape"), py::arg("type") = TFTypeNone); +// +// m.def("assert_tensor", [](const PyTensor& t, py::list target_shape, TFDataFormat target_type) { +// return PT(Tensor::Assert(T(t), Reverse(TensorsFromList(target_shape)), target_type)); +// }); +// m.def("split_dim", [](const PyTensor& t, const int split_size, const int axis) { +// return PT(Tensor::SplitDim(T(t), split_size, -axis-1)); +// }, py::arg("t"), py::arg("split_size"), py::arg("axis") = -1); +// m.def("merge_dim", [](const PyTensor& t, const int axis, const PyTensor* target_size) { +// const Tensor* target_size_ptr = target_size ? 
&T(*target_size) : nullptr; +// return PT(Tensor::MergeDim(T(t), -axis-1, target_size_ptr)); +// }, py::arg("t"), py::arg("axis") = -1, py::arg("target_size") = nullptr); +// m.def("repeat", [](const PyTensor& t, const PyTensor& repeats, const int axis) { +// return PT(Tensor::Repeat(T(t), T(repeats), -axis-1)); +// }, py::arg("t"), py::arg("repeats"), py::arg("axis") = -1); +// +// //algorithm functions +// m.def("sum", [](const PyTensor& t, const int axis) { return PT(Tensor::Sum(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Sum the elements of the tensor along the axis"); +// +// m.def("norm", [](const PyTensor& t, const int axis) { return PT(Tensor::Norm(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the norm of the tensor along the axis"); +// +// m.def("mean", [](const PyTensor& t, const int axis) { return PT(Tensor::Mean(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the mean of the tensor along the axis"); +// +// m.def("min", [](const PyTensor& t, const int axis) { return PT(Tensor::Min(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the min of the tensor along the axis"); +// +// m.def("max", [](const PyTensor& t, const int axis) { return PT(Tensor::Max(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the max of the tensor along the axis"); +// +// m.def("any", [](const PyTensor& t, const int axis) { return PT(Tensor::Any(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Do an OR operation along the axis"); +// +// m.def("all", [](const PyTensor& t, const int axis) { return PT(Tensor::All(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Do an AND operation along the axis"); +// +// m.def("prefix_sum", [](const PyTensor& t, const int axis) { return PT(Tensor::PrefixSum(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), 
py::arg("axis") = -1, "Compute the prefix sum of the tensor along the axis"); +// +// m.def("reverse", [](const PyTensor& t, const int axis) { return PT(Tensor::Reverse(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Reverse the tensor along the axis"); +// +// m.def("transpose", [](const PyTensor& t, int dim1, int dim2) { +// return PT(Tensor::Transpose(T(t), -dim1-1, -dim2-1)); +// }, py::arg("t"), py::kw_only(), py::arg("dim1") = -2, py::arg("dim2") = -1, "Transpose the tensor"); +// +// m.def("unsqueeze", [](const PyTensor& t, int dim) { +// return PT(Tensor::Unsqueeze(T(t), -dim-1)); +// }, py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Unsqueeze the tensor"); +// +// m.def("squeeze", [](const PyTensor& t, int dim) { +// return PT(Tensor::Squeeze(T(t), -dim-1)); +// }, py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Squeeze the tensor"); +// +// m.def("dot", [](const PyTensor& t, const PyTensor& t2, int axis) { +// return PT(Tensor::Dot(T(t), T(t2), -axis-1)); +// }, py::arg("t"), py::arg("t2"), py::kw_only(), py::arg("axis") = -1, "Dot product of two tensors"); +// +// m.def("matmul", [](const PyTensor& t, const PyTensor& t2) { +// return PT(Tensor::Matmul(T(t), T(t2))); +// }, py::arg("t"), py::arg("t2"), "Matrix multiplication of two tensors"); +// +// m.def("region_begin", [](const std::string& name) { +// Tensor::BeginRegion(name); +// }, py::arg("name"), "Begin a debug region"); +// +// m.def("region_end", [](const std::string& name) { +// Tensor::EndRegion(name); +// }, py::arg("name"), "End a debug region"); +// +// m.def("register_custom_operation", [](const std::string& name, vector overloads, py::function impl, py::function vjp) { +// auto cpp_impl = [impl](Tensors& output, map inputs, const Tensor* tensor, vector axes) { +// py::list input_list; +// for (auto& [id, tensor] : inputs) { +// input_list.append(PT(*tensor)); +// } +// py::list output_list = impl(input_list, PT(*tensor), axes).cast(); +// for (int i = 
0; i < output_list.size(); i++) { +// PyTensor* t = output_list[i].cast(); +// output.push_back(&t->Get()); +// } +// }; +// +// auto cpp_vjp = [vjp](map inputs, const Tensor* gradient, const Tensor* tensor) { +// py::list input_list; +// for (auto& [id, tensor] : inputs) { +// input_list.append(PT(*tensor)); +// } +// py::list output_list = vjp(input_list, PT(*gradient), PT(*tensor)).cast(); +// Tensors gradients; +// for (int i = 0; i < output_list.size(); i++) { +// PyTensor* t = output_list[i].cast(); +// gradients.push_back(&t->Get()); +// } +// return gradients; +// }; +// +// RegisterAlgorithmicPrimitive(name, overloads, cpp_impl, cpp_vjp); +// }, py::arg("name"), py::arg("overloads"), py::arg("impl"), py::arg("vjp"), "Register a custom operation"); +// +// m.def("custom", [](const std::string& name, py::list inputs, py::list shape) { +// Tensors input_tensors = TensorsFromList(inputs); +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// return PT(Tensor::CustomOperation(name, input_tensors, shape_tensors)); +// }, py::arg("name"), py::arg("inputs"), py::arg("shape"), "Run custom operation"); +// +// m.def("custom", [](const std::string& name, py::list inputs) { +// Tensors input_tensors = TensorsFromList(inputs); +// Tensors shape_tensors = input_tensors[0]->GetShape(); +// return PT(Tensor::CustomOperation(name, input_tensors, shape_tensors)); +// }, py::arg("name"), py::arg("inputs"), "Run custom operation"); +// +// m.def("print_value", [](const std::string& name, const PyTensor& t) { +// Tensor::PrintValue(name, T(t)); +// }, py::arg("name"), py::arg("t"), "Print the value of the tensor"); +// +// m.def("assert_value", [](const std::string& name, const PyTensor& t) { +// Tensor::AssertValue(name, T(t)); +// }, py::arg("name"), py::arg("t"), "Assert the value of the tensor"); +// } +// +// } // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/TensorMemory.cpp b/TensorFrost/src/Definitions/TensorMemory.cpp new file mode 100644 
index 00000000..f6186b49 --- /dev/null +++ b/TensorFrost/src/Definitions/TensorMemory.cpp @@ -0,0 +1,72 @@ +// #include +// #include +// +// #include +// #include +// +// namespace TensorFrost { +// +// void TensorMemoryDefinition(py::module& m, +// py::class_& py_tensor_mem) { +// //define constructors from numpy arrays +// py_tensor_mem.def(py::init([](py::array arr) { +// return PyTensorMemory(arr); +// }), "Create a TensorMemory from a numpy array", py::return_value_policy::take_ownership); +// +// // "constructor" +// m.def( +// "tensor", +// [](const std::vector& shape, TFDataFormat type) { +// return PyTensorMemory(shape, type); +// },"Create a TensorMemory with the given shape", py::return_value_policy::take_ownership); +// +// // "constructor" from numpy array +// m.def( +// "tensor", +// [](py::array arr) { +// return new PyTensorMemory(arr); +// }, +// "Create a TensorMemory from a numpy array", py::return_value_policy::take_ownership); +// +// // properties +// py_tensor_mem.def_property_readonly("shape", [](const PyTensorMemory& t) { +// vector shape = t.Shape(); +// return py::cast(shape); +// }); +// +// py_tensor_mem.def_property_readonly("type", [](const PyTensorMemory& t) { +// return t.GetFormat(); +// }); +// +// py_tensor_mem.def_property_readonly("size", [](const PyTensorMemory& t) { +// return GetSize(t.tensor_); +// }); +// +// // to numpy array +// py_tensor_mem.def_property_readonly( +// "numpy", +// [](const PyTensorMemory& t) +// -> std::variant, py::array_t, +// py::array_t, py::array_t> { +// if (t.GetFormat() == TFTypeFloat32) { +// return t.ToPyArray(); +// } else if (t.GetFormat() == TFTypeInt32) { +// return t.ToPyArray(); +// } else if (t.GetFormat() == TFTypeUint32) { +// return t.ToPyArray(); +// } else if (t.GetFormat() == TFTypeBool32) { +// return t.ToPyArray(); +// } else { +// throw std::runtime_error("Unsupported data type for numpy conversion"); +// } +// }, +// "Readback data from tensor memory to a numpy array", 
py::return_value_policy::take_ownership); +// +// m.def("allocated_memory", []() { return global_memory_manager->GetAllocatedSize(); }, +// "Get the amount of memory currently used by the memory manager"); +// +// m.def("unused_allocated_memory", []() { return global_memory_manager->GetUnusedAllocatedSize(); }, +// "Get the amount of memory currently allocated but not used by the memory manager"); +// } +// +// } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/src/Definitions/TensorProgram.cpp b/TensorFrost/src/Definitions/TensorProgram.cpp new file mode 100644 index 00000000..dd1cfcf5 --- /dev/null +++ b/TensorFrost/src/Definitions/TensorProgram.cpp @@ -0,0 +1,199 @@ +// #include +// #include +// +// #include +// #include +// #include +// +// namespace TensorFrost { +// +// void TensorProgramDefinition(py::module& m, +// py::class_& tensor_program) { +// m.def( +// "compile", +// [](const py::function& py_evaluate) { +// // Extract the name of the Python function +// std::string func_name = +// py_evaluate.attr("__name__").cast(); +// +// vector inputs = GetFunctionArguments(py_evaluate); +// vector arg_names; +// vector arg_props; +// for (auto arg : inputs) { +// arg_names.push_back(std::get<0>(arg)); +// py::object arg_prop = std::get<1>(arg); +// py::object arg_default = std::get<2>(arg); +// if (py::isinstance(arg_prop)) { +// PyTensorArg arg_tensor = arg_prop.cast(); +// arg_props.push_back(arg_tensor); +// } else { +// throw std::runtime_error("Unsupported input type " + std::string(py::str(arg_prop))); +// } +// } +// +// TensorProgram& program = *new TensorProgram( +// [py_evaluate, arg_names, arg_props]() -> Tensors { +// py::gil_scoped_acquire acquire; +// std::vector args; +// //create inputs from the arguments +// for (size_t i = 0; i < arg_names.size(); i++) { +// Tensor& input = Tensor::Input(arg_props[i].shape, arg_props[i].type); +// input.SetDebugName(arg_names[i]); +// PyTensor* py_tensor = new PyTensor(&input); +// 
args.push_back(py_tensor); +// } +// //convert to py::args +// py::args py_args = py::cast(args); +// py::object result = py_evaluate(*py_args); +// Tensors outputs; +// //if the result is a single tensor +// if (py::isinstance(result)) { +// outputs.push_back(&py::cast(result).Get()); +// } else { +// auto py_outputs = py::cast>(result); +// for (PyTensor output : py_outputs) { +// outputs.push_back(&output.Get()); +// } +// } +// return outputs; +// }, +// func_name); +// +// py::print(program.PrintProperties()); +// return &program; +// }, +// "Compile a TensorProgram from a python function"); +// +// tensor_program.def( +// "__call__", +// [](TensorProgram& program, py::args py_inputs) -> std::variant { +// vector inputs_props; +// vector temp_numpy_tensors; +// for (auto arg : py_inputs) { +// if (py::isinstance(arg)) { //if just tensor memory +// PyTensorMemory* mem = &arg.cast(); +// inputs_props.push_back(arg.cast()); +// } else if (py::isinstance(arg)) { //if module then add its parameters +// Module* module = &arg.cast(); +// py::list params = module->parameters(); +// for (auto param : params) { +// inputs_props.push_back(param.cast()); +// } +// } else if (py::isinstance(arg)) { //if numpy array then create pytensormemory from it and add it +// py::array arr = arg.cast(); +// PyTensorMemory* temp_tensor = new PyTensorMemory(arr); +// inputs_props.push_back(py::cast(temp_tensor, py::return_value_policy::take_ownership)); +// temp_numpy_tensors.push_back(temp_tensor->tensor_); +// } else if (py::isinstance(arg)) { //if list then convert to py::array then create pytensormemory from it and add it +// py::array arr = ListToArray(arg.cast()); +// PyTensorMemory* temp_tensor = new PyTensorMemory(arr); +// inputs_props.push_back(py::cast(temp_tensor, py::return_value_policy::take_ownership)); +// temp_numpy_tensors.push_back(temp_tensor->tensor_); +// } else { +// throw std::runtime_error("Unsupported input type " + std::string(py::str(arg))); +// } +// } +// 
+// vector inputs; +// for (auto input : inputs_props) { +// PyTensorMemory* mem = input.cast(); +// inputs.push_back(mem->tensor_); +// } +// vector outputs = program.Evaluate(inputs); +// +// //remove temporary tensors if they are not in the outputs +// for (TFTensor* temp_tensor : temp_numpy_tensors) { +// bool found = false; +// for (TFTensor* output : outputs) { +// if (temp_tensor->buffer == output->buffer) { +// found = true; +// break; +// } +// } +// if (!found) { +// global_memory_manager->DeallocateTensor(*temp_tensor); +// } +// } +// +// vector output_tensors; +// for (size_t i = 0; i < outputs.size(); i++) { +// //if any of the outputs are also inputs, then replace them with the input tensors +// TFTensor* out = outputs[i]; +// bool is_input = false; +// for (size_t j = 0; j < inputs_props.size(); j++) { +// PyTensorMemory* in = inputs_props[j].cast(); +// if (out->buffer == in->tensor_->buffer) { +// output_tensors.push_back(inputs_props[j]); +// is_input = true; +// break; +// } +// } +// if (is_input) { +// continue; +// } +// //otherwise create a new tensor memory +// output_tensors.push_back(py::cast(new PyTensorMemory(outputs[i]), py::return_value_policy::take_ownership)); +// } +// +// //if there is only one output, return the tensor memory +// if (outputs.size() == 1) { +// return output_tensors[0]; +// } else { +// //convert to py::tuple of PyTensorMemory* +// py::tuple py_outputs = py::tuple(outputs.size()); +// for (size_t i = 0; i < outputs.size(); i++) { +// py_outputs[i] = output_tensors[i]; +// } +// return py_outputs; +// } +// }, +// "Evaluate the TensorProgram with the given inputs"); +// +// tensor_program.def( +// "list_operations", +// [](TensorProgram& program, bool compact) { +// std::string listing = "List of operations:\n"; +// listing += GetOperationListing(program.ir, compact); +// return py::str(listing); +// }, +// py::arg("compact") = true); +// +// tensor_program.def("compiled_code", [](TensorProgram& program) { +// 
string code = program.program->generated_code_; +// return py::str(code); +// }); +// +// tensor_program.def("get_kernels", [](TensorProgram& program) { +// vector kernel_source; +// for (auto& kernel : program.program->kernels_) { +// kernel_source.push_back(kernel.full_generated_code_); +// } +// return kernel_source; +// }); +// +// tensor_program.def("get_main_function", [](TensorProgram& program) { +// return program.program->main_function_; +// }); +// +// tensor_program.def("get_last_execution_time", [](TensorProgram& program) { +// return program.program->last_execution_time; +// }); +// +// m.def("get_all_generated_main_functions", []() { +// return global_kernel_manager->GetAllMainFunctions(); +// }); +// +// m.def("get_all_generated_kernels", []() { +// return global_kernel_manager->GetAllKernels(); +// }); +// +// m.def("get_cpp_header", []() { +// return GetCPPHeader(); +// }); +// +// m.def("get_cpp_implementation", []() { +// return GetCPPImplementation(); +// }); +// } +// +// } // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/TensorScope.cpp b/TensorFrost/src/Definitions/TensorScope.cpp new file mode 100644 index 00000000..07bac335 --- /dev/null +++ b/TensorFrost/src/Definitions/TensorScope.cpp @@ -0,0 +1,118 @@ +// #include +// #include +// +// #include +// +// namespace TensorFrost { +// +// void ScopeDefinitions(py::module& m, py::class_& py_tensor) { +// m.def( +// "loop", +// [](const py::function& body, const PyTensor& begin, const PyTensor& end, +// const PyTensor& step) { +// // wrap the function to convert the PyTensor to Tensor +// std::function f2 = [&body](const Tensor& t) { +// py::gil_scoped_acquire acquire; +// body(PT(t)); +// }; +// +// Tensor::Loop(T(begin), T(end), T(step), f2); +// }, +// py::arg("begin") = 0, py::arg("end"), py::arg("step") = 1, +// py::arg("body")); +// +// m.def( +// "if_cond", +// [](const PyTensor& condition, const py::function& true_body) { +// std::function f = [&true_body]() { +// 
py::gil_scoped_acquire acquire; +// true_body(); +// }; +// Tensor::If(T(condition), f); +// }, +// py::arg("condition"), py::arg("true_body")); +// +// m.def( +// "if_cond", +// [](const PyTensor& condition, const py::function& true_body, +// const py::function& false_body) { +// std::function f1 = [&true_body]() { +// py::gil_scoped_acquire acquire; +// true_body(); +// }; +// std::function f2 = [&false_body]() { +// py::gil_scoped_acquire acquire; +// false_body(); +// }; +// Tensor::If(T(condition), f1, f2); +// }, +// py::arg("condition"), py::arg("true_body"), py::arg("false_body")); +// +// m.def("break_loop", []() { Tensor::Break(); }); +// m.def("continue_loop", []() { Tensor::Continue(); }); +// +// m.def( +// "kernel", +// [](py::list shape, const py::function& body) { +// // wrap the function to convert the PyTensor to Tensor +// std::function f2 = +// [&body](const Tensors& tensors) { +// py::gil_scoped_acquire acquire; +// PyTensors py_tensors = PyTensorsFromTensors(tensors); +// body(py_tensors); +// }; +// +// Tensor::Kernel(Reverse(TensorsFromList(shape)), f2); +// }, +// py::arg("shape"), py::arg("body")); +// +// // m.def( +// // "vmap", +// // [](py::list inputs, py::list shape, const py::function& func) { +// // std::function f = [&func]() { +// // py::gil_scoped_acquire acquire; +// // func(); +// // }; +// // Tensor::If(T(condition), f); +// // }, +// // py::arg("condition"), py::arg("true_body")); +// +// py_tensor.def("__enter__", &PyTensor::__enter__); +// py_tensor.def("__exit__", &PyTensor::__exit__); +// +// //loop scope +// m.def("loop", +// [](const PyTensor& begin, const PyTensor& end, const PyTensor& step) { +// Tensor& for_loop = Tensor::Loop(T(begin), T(end), T(step)); +// return PT(for_loop); +// }); +// +// m.def("loop", +// [](const PyTensor& begin, const PyTensor& end) { +// Tensor& for_loop = Tensor::Loop(T(begin), T(end), T(PyTensor(1))); +// return PT(for_loop); +// }); +// +// m.def("loop", +// [](const PyTensor& end) { 
+// Tensor& for_loop = Tensor::Loop(T(PyTensor(0)), T(end), T(PyTensor(1))); +// return PT(for_loop); +// }); +// +// //if scope +// m.def("if_cond", +// [](const PyTensor& condition) { +// Tensor& if_cond = Tensor::If(T(condition)); +// return PT(if_cond); +// }); +// +// //kernel scope +// m.def("kernel", +// [](py::list shape, vector group_size) { +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// Tensor& kernel = Tensor::Kernel(shape_tensors, group_size); +// return PT(kernel); +// }, py::arg("shape"), py::arg("group_size") = vector()); +// } +// +// } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/src/Definitions/WindowUtils.cpp b/TensorFrost/src/Definitions/WindowUtils.cpp new file mode 100644 index 00000000..5bbf4647 --- /dev/null +++ b/TensorFrost/src/Definitions/WindowUtils.cpp @@ -0,0 +1,148 @@ +// #include +// #include +// +// #include +// #include +// +// #include "Backend/RenderDoc.h" +// +// namespace TensorFrost { +// +// void WindowDefinitions(py::module& m) { +// py::module window = m.def_submodule("window", "Window functions"); +// +// window.def( +// "show", +// [](int width, int height, string title) { +// ShowWindow(width, height, title.c_str()); +// }, +// "Show the memory manager window"); +// +// window.def( +// "hide", []() { HideWindow(); }, "Hide the memory manager window"); +// +// window.def( +// "render_frame", [](const PyTensorMemory& t) { RenderFrame(t.tensor_); }, +// "Render a frame from the tensor memory"); +// +// window.def("render_frame", []() { RenderFrame(nullptr); }, +// "Render an empty frame"); +// +// window.def( +// "should_close", []() { return WindowShouldClose(); }, +// "Check if the window should close"); +// +// window.def( +// "get_mouse_position", []() { return GetMousePosition(); }, +// "Get the current mouse position"); +// +// window.def( +// "get_size", []() { return GetWindowSize(); }, +// "Get the current window size"); +// +// window.def( +// 
"is_mouse_button_pressed", +// [](int button) { return IsMouseButtonPressed(button); }, +// "Check if a mouse button is pressed"); +// +// window.def( +// "is_key_pressed", [](int key) { return IsKeyPressed(key); }, +// "Check if a key is pressed"); +// +// window.attr("MOUSE_BUTTON_0") = GLFW_MOUSE_BUTTON_1; +// window.attr("MOUSE_BUTTON_1") = GLFW_MOUSE_BUTTON_2; +// window.attr("MOUSE_BUTTON_2") = GLFW_MOUSE_BUTTON_3; +// +// window.attr("KEY_SPACE") = GLFW_KEY_SPACE; +// window.attr("KEY_APOSTROPHE") = GLFW_KEY_APOSTROPHE; +// window.attr("KEY_COMMA") = GLFW_KEY_COMMA; +// window.attr("KEY_MINUS") = GLFW_KEY_MINUS; +// window.attr("KEY_PERIOD") = GLFW_KEY_PERIOD; +// +// window.attr("KEY_A") = GLFW_KEY_A; +// window.attr("KEY_B") = GLFW_KEY_B; +// window.attr("KEY_C") = GLFW_KEY_C; +// window.attr("KEY_D") = GLFW_KEY_D; +// window.attr("KEY_E") = GLFW_KEY_E; +// window.attr("KEY_F") = GLFW_KEY_F; +// window.attr("KEY_G") = GLFW_KEY_G; +// window.attr("KEY_H") = GLFW_KEY_H; +// window.attr("KEY_I") = GLFW_KEY_I; +// window.attr("KEY_J") = GLFW_KEY_J; +// window.attr("KEY_K") = GLFW_KEY_K; +// window.attr("KEY_L") = GLFW_KEY_L; +// window.attr("KEY_M") = GLFW_KEY_M; +// window.attr("KEY_N") = GLFW_KEY_N; +// window.attr("KEY_O") = GLFW_KEY_O; +// window.attr("KEY_P") = GLFW_KEY_P; +// window.attr("KEY_Q") = GLFW_KEY_Q; +// window.attr("KEY_R") = GLFW_KEY_R; +// window.attr("KEY_S") = GLFW_KEY_S; +// window.attr("KEY_T") = GLFW_KEY_T; +// window.attr("KEY_U") = GLFW_KEY_U; +// window.attr("KEY_V") = GLFW_KEY_V; +// window.attr("KEY_W") = GLFW_KEY_W; +// window.attr("KEY_X") = GLFW_KEY_X; +// window.attr("KEY_Y") = GLFW_KEY_Y; +// window.attr("KEY_Z") = GLFW_KEY_Z; +// +// py::module imgui = m.def_submodule("imgui", "ImGui functions"); +// +// imgui.def( +// "begin", [](string name) { ImGuiBegin(name); }, +// "Begin a new ImGui window"); +// +// imgui.def("end", []() { ImGuiEnd(); }, "End the current ImGui window"); +// +// imgui.def( +// "text", [](string text) { 
ImGuiText(text); }, +// "Add text to the current ImGui window"); +// +// imgui.def("slider", +// [](string text, int* value, int min, int max) { +// ImGuiSlider(text, value, min, max); +// return *value; +// }, "Add a slider to the current ImGui window"); +// +// imgui.def("slider", [](string text, float* value, float min, float max) { +// ImGuiSlider(text, value, min, max); +// return *value; +// }, "Add a slider to the current ImGui window"); +// +// imgui.def("button", [](string text) { return ImGuiButton(text); }, +// "Add a button to the current ImGui window"); +// +// imgui.def("checkbox", [](string text, bool* value) { +// ImGuiCheckbox(text, value); +// return *value; +// }, "Add a checkbox to the current ImGui window"); +// +// imgui.def("plotlines", [](string label, py::array_t values, int values_offset, string overlay_text, float scale_min, float scale_max, py::tuple graph_size, int stride) { +// ImGuiPlotLines(label.c_str(), values.data(), (int)values.size(), values_offset, overlay_text.c_str(), scale_min, scale_max, ImVec2(graph_size[0].cast(), graph_size[1].cast()), stride); +// }, py::arg("label"), py::arg("values"), py::arg("values_offset") = 0, py::arg("overlay_text") = "", py::arg("scale_min") = FLT_MAX, py::arg("scale_max") = FLT_MAX, py::arg("graph_size") = py::make_tuple(0.0f, 0.0f), py::arg("stride") = sizeof(float)); +// +// imgui.def("scale_all_sizes", [](float scale) { ImGuiScaleAllSizes(scale); }, +// "Scale all ImGui sizes by a factor"); +// +// imgui.def("add_background_text", [](string text, py::tuple pos, py::tuple color) { +// ImGuiAddBackgroundText(text, ImVec2(pos[0].cast(), pos[1].cast()), ImVec4(color[0].cast(), color[1].cast(), color[2].cast(), color[3].cast())); +// }, py::arg("text"), py::arg("pos"), py::arg("color")); +// +// imgui.def("color_picker3", [](string text, py::array_t color) { +// ImGuiColorPicker3(text, color.mutable_data()); +// return color; +// }, py::arg("text"), py::arg("color")); +// +// 
imgui.def("color_picker4", [](string text, py::array_t color) { +// ImGuiColorPicker4(text, color.mutable_data()); +// return color; +// }, py::arg("text"), py::arg("color")); +// +// //renderdoc +// m.def("renderdoc_start_capture", []() { StartRenderDocCapture(); }, +// "Start a RenderDoc capture"); +// m.def("renderdoc_end_capture", []() { EndRenderDocCapture(); }, +// "End a RenderDoc capture"); +// } +// +// } // namespace TensorFrost \ No newline at end of file From bd53f728050ef583d85c47301db8212f2b5bb329 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Sun, 1 Jun 2025 05:27:50 +0200 Subject: [PATCH 04/44] Some tests --- TensorFrost/PybindModule.cpp | 9 +++++++++ TensorFrost/include/Compiler/Overloads.h | 13 ++++++------- TensorFrost/include/TensorFrost.h | 7 +++++++ 3 files changed, 22 insertions(+), 7 deletions(-) create mode 100644 TensorFrost/include/TensorFrost.h diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 4de7cf13..258448bd 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -9,6 +9,8 @@ #include #include +#include "TensorFrost.h" + namespace py = pybind11; namespace TensorFrost { @@ -129,6 +131,13 @@ PYBIND11_MODULE(TensorFrost, m) { #else py::print("TensorFrost module loaded in debug mode! Expect slow performance."); #endif + + // TEST CODE + StartExecutionContext(); + Op& a = constant(5); + Op& b = constant(10); + Op& c = a + b; + py::print("Created operation: ", c.opcode); } } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h index bdb94c3f..bbc84d7d 100644 --- a/TensorFrost/include/Compiler/Overloads.h +++ b/TensorFrost/include/Compiler/Overloads.h @@ -5,13 +5,12 @@ namespace TensorFrost { Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args, std::vector shape); -template -Op& func_op(std::string op, const Args&... 
args) { - std::vector mem; - std::vector ids; - std::vector args_vec = {&args...}; - std::vector shape; - return make_op(op, mem, ids, args_vec, shape); +template +Op &func_op(const std::string &name, const Ts &... args) { + std::vector v; + v.reserve(sizeof...(args)); + (v.push_back(const_cast(&args)), ...); + return make_op(name, {}, {}, v, {}); } Op& constant(int value); diff --git a/TensorFrost/include/TensorFrost.h b/TensorFrost/include/TensorFrost.h new file mode 100644 index 00000000..140eb362 --- /dev/null +++ b/TensorFrost/include/TensorFrost.h @@ -0,0 +1,7 @@ +#pragma once + +#include "Compiler/Operation.h" +#include "Compiler/ExecutionContext.h" +#include "Compiler/OperationBlocks.h" +#include "Compiler/OperationArguments.h" +#include "Compiler/Overloads.h" \ No newline at end of file From 60268b1388c788df9780b0f82bc8390feb55ef53 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Sun, 1 Jun 2025 05:52:57 +0200 Subject: [PATCH 05/44] Working super minimal test --- Python/TensorFrost/clipping.py | 10 +- Python/TensorFrost/default.py | 28 +- Python/TensorFrost/optimizers.py | 438 +++++++++--------- Python/TensorFrost/random.py | 90 ++-- Python/TensorFrost/regularizers.py | 10 +- Python/TensorFrost/sort.py | 374 +++++++-------- TensorFrost/include/Compiler/Operation.h | 5 +- .../include/Compiler/OperationBlocks.h | 3 + TensorFrost/src/Compiler/ExecutionContext.cpp | 2 +- TensorFrost/src/Compiler/Operation.cpp | 21 +- TensorFrost/src/Compiler/OperationBlocks.cpp | 2 + .../src/Compiler/OperationRegistry.cpp | 2 + examples/debug.py | 44 -- 13 files changed, 487 insertions(+), 542 deletions(-) diff --git a/Python/TensorFrost/clipping.py b/Python/TensorFrost/clipping.py index b343917e..7c79851d 100644 --- a/Python/TensorFrost/clipping.py +++ b/Python/TensorFrost/clipping.py @@ -1,5 +1,5 @@ -from .optimizers import * - -clamp = ModuleOptimizer.ClippingType.Clamp -norm = ModuleOptimizer.ClippingType.Norm -none = ModuleOptimizer.ClippingType.None_ \ No newline at 
end of file +# from .optimizers import * +# +# clamp = ModuleOptimizer.ClippingType.Clamp +# norm = ModuleOptimizer.ClippingType.Norm +# none = ModuleOptimizer.ClippingType.None_ \ No newline at end of file diff --git a/Python/TensorFrost/default.py b/Python/TensorFrost/default.py index 41853504..a015108c 100644 --- a/Python/TensorFrost/default.py +++ b/Python/TensorFrost/default.py @@ -1,14 +1,14 @@ -from . import TensorFrost as tf - -def zeros_like(tensor): - return tf.zeros(tensor.shape, tensor.type) - -def eye(n): - i, j = tf.indices([n, n]) - return tf.select(i == j, 1.0, 0.0) - -def eye_like(tensor): - return eye(tensor.shape[0]) - -def ones_like(tensor): - return tf.ones(tensor.shape, tensor.type) \ No newline at end of file +# from . import TensorFrost as tf +# +# def zeros_like(tensor): +# return tf.zeros(tensor.shape, tensor.type) +# +# def eye(n): +# i, j = tf.indices([n, n]) +# return tf.select(i == j, 1.0, 0.0) +# +# def eye_like(tensor): +# return eye(tensor.shape[0]) +# +# def ones_like(tensor): +# return tf.ones(tensor.shape, tensor.type) \ No newline at end of file diff --git a/Python/TensorFrost/optimizers.py b/Python/TensorFrost/optimizers.py index 731ae81b..21b07ce4 100644 --- a/Python/TensorFrost/optimizers.py +++ b/Python/TensorFrost/optimizers.py @@ -1,219 +1,219 @@ -from . 
import TensorFrost as tf - -class ModuleOptimizer(tf.Module): - class OptimizerType: - ADAM = 0 - SGD = 1 - RMSProp = 2 - - class RegularizerType: - None_ = 0 - L1 = 1 - L2 = 2 - - class ClippingType: - Clamp = 0 - Norm = 1 - None_ = 2 - - def __init__(self, optimizer_type, regularizer_type, net, params): - super().__init__() - self.optimizer_type = optimizer_type - self.regularizer_type = regularizer_type - self.clipping_type = self.ClippingType.Clamp - self.epsilon = 1e-8 - - # Set passed parameters as attributes - self.net = net - for k, v in params.items(): - setattr(self, k, v) - - # Initialize t - t = tf.Parameter([1], tf.float32, False) # mimic Parameter({1}, TFType::Float, false) - self.t = t - - self.initializeOptimizer(net) - - def set_clipping_type(self, ctype): - self.clipping_type = ctype - - def initializeOptimizer(self, net): - net_params = net.parameters() - requires_grads = net.requires_grads_list() - - if self.optimizer_type == self.OptimizerType.ADAM: - self.initializeParameterArray("m", net_params, requires_grads) - self.initializeParameterArray("v", net_params, requires_grads) - elif self.optimizer_type == self.OptimizerType.SGD: - # No additional parameters needed - pass - elif self.optimizer_type == self.OptimizerType.RMSProp: - self.initializeParameterArray("v", net_params, requires_grads) - - def initializeParameterArray(self, name, net_params, requires_grads): - arr = tf.ParameterArray() - - for i, param in enumerate(net_params): - if not requires_grads[i]: - continue - - new_param = tf.Parameter(param.shape, tf.float32, False) - arr[i] = new_param - - setattr(self, name, arr) - - def assert_parameters(self): - net_params = self.net.parameters() - requires_grads = self.net.requires_grads_list() - self.assertParameterArray("m", net_params, requires_grads) - self.assertParameterArray("v", net_params, requires_grads) - - def gradient_norm(self, grad): - # sum of squares - g = grad * grad - shape = grad.shape - num_dims = len(shape) - for i in 
range(num_dims): - g = tf.sum(g) - return tf.sqrt(g) - - def assertParameterArray(self, name, net_params, requires_grads): - if hasattr(self, name): - arr = getattr(self, name) - for i, param in enumerate(net_params): - if not requires_grads[i]: - continue - arr_item = arr[i] - arr_item = tf.assert_tensor(arr_item, param.shape, param.type) - arr[i] = arr_item - - def step(self, *args): - # Overloaded step: - # step(X, Y) or step(loss) - if len(args) == 2: - X, Y = args - loss = self.net.loss(X, Y) - self._step(loss) - return loss - elif len(args) == 1: - (loss,) = args - self._step(loss) - else: - raise ValueError("Invalid arguments to step") - - def _step(self, loss): - # Increment t by 1 - self.t = self.t + 1.0 - - net = self.net - net_params = net.parameters() - requires_grads = net.requires_grads_list() - - learning_rate = self.learning_rate - grad_clip = self.grad_clip - has_clip = isinstance(grad_clip, float) and grad_clip > 0.0 - - for i, param in enumerate(net_params): - if not requires_grads[i]: - continue - - grad = tf.grad(loss, param) - if has_clip: - if self.clipping_type == self.ClippingType.Clamp: - grad = tf.clamp(grad, -grad_clip, grad_clip) - elif self.clipping_type == self.ClippingType.Norm: - grad_norm = tf.max(1e-6, self.gradient_norm(grad)) - grad = grad * tf.min(1.0, grad_clip / grad_norm) - - if self.optimizer_type == self.OptimizerType.ADAM: - update = self.adam_update(i, param, grad, self.t, learning_rate) - elif self.optimizer_type == self.OptimizerType.SGD: - update = self.sgd_update(param, grad, learning_rate) - elif self.optimizer_type == self.OptimizerType.RMSProp: - update = self.rmsprop_update(i, param, grad, learning_rate) - else: - raise RuntimeError("Unknown optimizer type") - - # Apply regularization if needed - if self.regularizer_type == self.RegularizerType.L1: - param = param - learning_rate * self.reg * tf.sign(param) - elif self.regularizer_type == self.RegularizerType.L2: - param = param - learning_rate * self.reg * param 
- - # Update parameter with computed update - param = param - update - net_params[i] = param - - net.update_parameters(net_params) - - def adam_update(self, i, param, grad, t, learning_rate): - beta1 = tf.float(self.beta1) - beta2 = tf.float(self.beta2) - - m = self.m[i] - v = self.v[i] - - m = tf.lerp(grad, m, beta1) - v = tf.lerp(grad * grad, v, beta2) - - # t is a Parameter with shape [1]; get the scalar - t_val = self.t[0] - mhat = m / (1.0 - tf.pow(beta1, t_val)) - vhat = v / (1.0 - tf.pow(beta2, t_val)) - - self.m[i] = m - self.v[i] = v - - return learning_rate * mhat / (tf.sqrt(vhat) + self.epsilon) - - def sgd_update(self, param, grad, learning_rate): - return learning_rate * grad - - def rmsprop_update(self, i, param, grad, learning_rate): - decay = tf.float(self.decay) - - v = self.v[i] - v = tf.lerp(grad * grad, v, decay) - self.v[i] = v - - return (grad * learning_rate) / (tf.sqrt(v) + self.epsilon) - - -def adam(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, beta1=0.9, beta2=0.999, clip=0.0, reg=0.0): - return ModuleOptimizer( - ModuleOptimizer.OptimizerType.ADAM, - reg_type, - net, - { - "learning_rate": learning_rate, - "beta1": beta1, - "beta2": beta2, - "grad_clip": clip, - "reg": reg, - } - ) - -def sgd(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, clip=0.0, reg=0.0): - return ModuleOptimizer( - ModuleOptimizer.OptimizerType.SGD, - reg_type, - net, - { - "learning_rate": learning_rate, - "grad_clip": clip, - "reg": reg, - } - ) - -def rmsprop(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, decay=0.9, clip=0.0, reg=0.0): - return ModuleOptimizer( - ModuleOptimizer.OptimizerType.RMSProp, - reg_type, - net, - { - "learning_rate": learning_rate, - "decay": decay, - "grad_clip": clip, - "reg": reg, - } - ) \ No newline at end of file +# from . 
import TensorFrost as tf +# +# class ModuleOptimizer(tf.Module): +# class OptimizerType: +# ADAM = 0 +# SGD = 1 +# RMSProp = 2 +# +# class RegularizerType: +# None_ = 0 +# L1 = 1 +# L2 = 2 +# +# class ClippingType: +# Clamp = 0 +# Norm = 1 +# None_ = 2 +# +# def __init__(self, optimizer_type, regularizer_type, net, params): +# super().__init__() +# self.optimizer_type = optimizer_type +# self.regularizer_type = regularizer_type +# self.clipping_type = self.ClippingType.Clamp +# self.epsilon = 1e-8 +# +# # Set passed parameters as attributes +# self.net = net +# for k, v in params.items(): +# setattr(self, k, v) +# +# # Initialize t +# t = tf.Parameter([1], tf.float32, False) # mimic Parameter({1}, TFType::Float, false) +# self.t = t +# +# self.initializeOptimizer(net) +# +# def set_clipping_type(self, ctype): +# self.clipping_type = ctype +# +# def initializeOptimizer(self, net): +# net_params = net.parameters() +# requires_grads = net.requires_grads_list() +# +# if self.optimizer_type == self.OptimizerType.ADAM: +# self.initializeParameterArray("m", net_params, requires_grads) +# self.initializeParameterArray("v", net_params, requires_grads) +# elif self.optimizer_type == self.OptimizerType.SGD: +# # No additional parameters needed +# pass +# elif self.optimizer_type == self.OptimizerType.RMSProp: +# self.initializeParameterArray("v", net_params, requires_grads) +# +# def initializeParameterArray(self, name, net_params, requires_grads): +# arr = tf.ParameterArray() +# +# for i, param in enumerate(net_params): +# if not requires_grads[i]: +# continue +# +# new_param = tf.Parameter(param.shape, tf.float32, False) +# arr[i] = new_param +# +# setattr(self, name, arr) +# +# def assert_parameters(self): +# net_params = self.net.parameters() +# requires_grads = self.net.requires_grads_list() +# self.assertParameterArray("m", net_params, requires_grads) +# self.assertParameterArray("v", net_params, requires_grads) +# +# def gradient_norm(self, grad): +# # sum of squares 
+# g = grad * grad +# shape = grad.shape +# num_dims = len(shape) +# for i in range(num_dims): +# g = tf.sum(g) +# return tf.sqrt(g) +# +# def assertParameterArray(self, name, net_params, requires_grads): +# if hasattr(self, name): +# arr = getattr(self, name) +# for i, param in enumerate(net_params): +# if not requires_grads[i]: +# continue +# arr_item = arr[i] +# arr_item = tf.assert_tensor(arr_item, param.shape, param.type) +# arr[i] = arr_item +# +# def step(self, *args): +# # Overloaded step: +# # step(X, Y) or step(loss) +# if len(args) == 2: +# X, Y = args +# loss = self.net.loss(X, Y) +# self._step(loss) +# return loss +# elif len(args) == 1: +# (loss,) = args +# self._step(loss) +# else: +# raise ValueError("Invalid arguments to step") +# +# def _step(self, loss): +# # Increment t by 1 +# self.t = self.t + 1.0 +# +# net = self.net +# net_params = net.parameters() +# requires_grads = net.requires_grads_list() +# +# learning_rate = self.learning_rate +# grad_clip = self.grad_clip +# has_clip = isinstance(grad_clip, float) and grad_clip > 0.0 +# +# for i, param in enumerate(net_params): +# if not requires_grads[i]: +# continue +# +# grad = tf.grad(loss, param) +# if has_clip: +# if self.clipping_type == self.ClippingType.Clamp: +# grad = tf.clamp(grad, -grad_clip, grad_clip) +# elif self.clipping_type == self.ClippingType.Norm: +# grad_norm = tf.max(1e-6, self.gradient_norm(grad)) +# grad = grad * tf.min(1.0, grad_clip / grad_norm) +# +# if self.optimizer_type == self.OptimizerType.ADAM: +# update = self.adam_update(i, param, grad, self.t, learning_rate) +# elif self.optimizer_type == self.OptimizerType.SGD: +# update = self.sgd_update(param, grad, learning_rate) +# elif self.optimizer_type == self.OptimizerType.RMSProp: +# update = self.rmsprop_update(i, param, grad, learning_rate) +# else: +# raise RuntimeError("Unknown optimizer type") +# +# # Apply regularization if needed +# if self.regularizer_type == self.RegularizerType.L1: +# param = param - 
learning_rate * self.reg * tf.sign(param) +# elif self.regularizer_type == self.RegularizerType.L2: +# param = param - learning_rate * self.reg * param +# +# # Update parameter with computed update +# param = param - update +# net_params[i] = param +# +# net.update_parameters(net_params) +# +# def adam_update(self, i, param, grad, t, learning_rate): +# beta1 = tf.float(self.beta1) +# beta2 = tf.float(self.beta2) +# +# m = self.m[i] +# v = self.v[i] +# +# m = tf.lerp(grad, m, beta1) +# v = tf.lerp(grad * grad, v, beta2) +# +# # t is a Parameter with shape [1]; get the scalar +# t_val = self.t[0] +# mhat = m / (1.0 - tf.pow(beta1, t_val)) +# vhat = v / (1.0 - tf.pow(beta2, t_val)) +# +# self.m[i] = m +# self.v[i] = v +# +# return learning_rate * mhat / (tf.sqrt(vhat) + self.epsilon) +# +# def sgd_update(self, param, grad, learning_rate): +# return learning_rate * grad +# +# def rmsprop_update(self, i, param, grad, learning_rate): +# decay = tf.float(self.decay) +# +# v = self.v[i] +# v = tf.lerp(grad * grad, v, decay) +# self.v[i] = v +# +# return (grad * learning_rate) / (tf.sqrt(v) + self.epsilon) +# +# +# def adam(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, beta1=0.9, beta2=0.999, clip=0.0, reg=0.0): +# return ModuleOptimizer( +# ModuleOptimizer.OptimizerType.ADAM, +# reg_type, +# net, +# { +# "learning_rate": learning_rate, +# "beta1": beta1, +# "beta2": beta2, +# "grad_clip": clip, +# "reg": reg, +# } +# ) +# +# def sgd(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, clip=0.0, reg=0.0): +# return ModuleOptimizer( +# ModuleOptimizer.OptimizerType.SGD, +# reg_type, +# net, +# { +# "learning_rate": learning_rate, +# "grad_clip": clip, +# "reg": reg, +# } +# ) +# +# def rmsprop(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, decay=0.9, clip=0.0, reg=0.0): +# return ModuleOptimizer( +# ModuleOptimizer.OptimizerType.RMSProp, +# reg_type, +# net, +# { +# "learning_rate": learning_rate, +# 
"decay": decay, +# "grad_clip": clip, +# "reg": reg, +# } +# ) \ No newline at end of file diff --git a/Python/TensorFrost/random.py b/Python/TensorFrost/random.py index de3d0ad0..ab91f028 100644 --- a/Python/TensorFrost/random.py +++ b/Python/TensorFrost/random.py @@ -1,45 +1,45 @@ -from . import TensorFrost as tf - -def randn2(shape, seed=0): - #Box-Muller transform - r1 = tf.random_value(shape, seed=seed) - r2 = tf.random_value(shape, seed=tf.hash(seed)) - rho = tf.sqrt(-2.0*tf.log(tf.max(1e-6, r1))) - theta = 2.0*tf.pi*r2 - return rho*tf.cos(theta), rho*tf.sin(theta) - -def randn(shape, seed=0): - return randn2(shape, seed=seed)[0] - -def rand(shape, seed=0): - return tf.random_value(shape, seed=seed) - -def randn_like(tensor, seed=0): - return randn(tensor.shape, seed=seed) - -def rand_like(tensor, seed=0): - return rand(tensor.shape, seed=seed) - -def rand_int(seed, max_value): - return tf.int(tf.pcg(tf.uint(seed)) % tf.uint(max_value)) - -def xor_swap(idx, n, seed): - xor_seed = rand_int(seed, n) - xor_idx = (idx ^ xor_seed) - max_idx = tf.max(idx, xor_idx) - min_idx = tf.min(idx, xor_idx) - swap = rand_int(min_idx * 451 + seed, 2) == 0 - return tf.select(swap & (max_idx < n), xor_idx, idx) - -def reverse(idx, n): - return n - 1 - idx - -def shuffle(idx, n, seed = 0, iters = 16): - for i in range(iters): - idx = xor_swap(idx, n, seed + i) - idx = reverse(idx, n) - return idx - -def permutation(n, seed = 0): - idx = tf.indices([n])[0] - return shuffle(idx, n, seed) \ No newline at end of file +# from . 
import TensorFrost as tf +# +# def randn2(shape, seed=0): +# #Box-Muller transform +# r1 = tf.random_value(shape, seed=seed) +# r2 = tf.random_value(shape, seed=tf.hash(seed)) +# rho = tf.sqrt(-2.0*tf.log(tf.max(1e-6, r1))) +# theta = 2.0*tf.pi*r2 +# return rho*tf.cos(theta), rho*tf.sin(theta) +# +# def randn(shape, seed=0): +# return randn2(shape, seed=seed)[0] +# +# def rand(shape, seed=0): +# return tf.random_value(shape, seed=seed) +# +# def randn_like(tensor, seed=0): +# return randn(tensor.shape, seed=seed) +# +# def rand_like(tensor, seed=0): +# return rand(tensor.shape, seed=seed) +# +# def rand_int(seed, max_value): +# return tf.int(tf.pcg(tf.uint(seed)) % tf.uint(max_value)) +# +# def xor_swap(idx, n, seed): +# xor_seed = rand_int(seed, n) +# xor_idx = (idx ^ xor_seed) +# max_idx = tf.max(idx, xor_idx) +# min_idx = tf.min(idx, xor_idx) +# swap = rand_int(min_idx * 451 + seed, 2) == 0 +# return tf.select(swap & (max_idx < n), xor_idx, idx) +# +# def reverse(idx, n): +# return n - 1 - idx +# +# def shuffle(idx, n, seed = 0, iters = 16): +# for i in range(iters): +# idx = xor_swap(idx, n, seed + i) +# idx = reverse(idx, n) +# return idx +# +# def permutation(n, seed = 0): +# idx = tf.indices([n])[0] +# return shuffle(idx, n, seed) \ No newline at end of file diff --git a/Python/TensorFrost/regularizers.py b/Python/TensorFrost/regularizers.py index b384bcef..989767cc 100644 --- a/Python/TensorFrost/regularizers.py +++ b/Python/TensorFrost/regularizers.py @@ -1,5 +1,5 @@ -from .optimizers import * - -l1 = ModuleOptimizer.RegularizerType.L1 -l2 = ModuleOptimizer.RegularizerType.L2 -none = ModuleOptimizer.RegularizerType.None_ \ No newline at end of file +# from .optimizers import * +# +# l1 = ModuleOptimizer.RegularizerType.L1 +# l2 = ModuleOptimizer.RegularizerType.L2 +# none = ModuleOptimizer.RegularizerType.None_ \ No newline at end of file diff --git a/Python/TensorFrost/sort.py b/Python/TensorFrost/sort.py index 4d585c64..d3bfc143 100644 --- 
a/Python/TensorFrost/sort.py +++ b/Python/TensorFrost/sort.py @@ -1,187 +1,187 @@ -from . import TensorFrost as tf - -#in-place bitonic sort -def bitonic(keys, values = None): - tf.region_begin('Bitonic sort') - keys = tf.copy(keys) - if values is not None: - values = tf.copy(values) - element_count = keys.shape[0] - log2_count = tf.int(tf.ceil(tf.log2(tf.float(element_count)))) - count_round = 1 << log2_count - idx = tf.indices([count_round / 2])[0] - with tf.loop(log2_count) as k: - with tf.loop(k+1) as j: - s = 1 << (k-j) - m_inner = s - 1 - m_outer = ~m_inner - m_xor = s + tf.select(j == 0, m_inner, 0) - - id1 = (2 * (idx & m_outer) + (idx & m_inner)) - id2 = id1 ^ m_xor - key1, key2 = keys[id1], keys[id2] - with tf.if_cond((key1 >= key2) & (id1 < element_count) & (id2 < element_count)): - if values is not None: - val1, val2 = values[id1], values[id2] - values[id1] = val2 - values[id2] = val1 - keys[id1] = key2 - keys[id2] = key1 - - tf.region_end('Bitonic sort') - if values is not None: - return keys, values - else: - return keys - -#histogram radix sort -def radix(keys, values = None, bits_per_pass = 6, max_bits = 32): - def prefix_sum_grouped(A, axis = -1): - axis = len(A.shape) + axis if axis < 0 else axis - group_size = 64 - grouped = tf.split_dim(A, group_size, axis) - group_scan = tf.prefix_sum(tf.sum(grouped, axis = axis + 1), axis = axis) - ids = grouped.indices - gid, eid = ids[axis], ids[axis + 1] - ids = [ids[i] for i in range(len(ids)) if i != axis + 1] - ids[axis] = gid - 1 - group_scan = tf.prefix_sum(grouped + tf.select((gid == 0) | (eid != 0), tf.uint(0), group_scan[tuple(ids)]), axis = axis + 1) - full_scan = tf.merge_dim(group_scan, target_size = A.shape[axis], axis = axis + 1) - return full_scan - - sign_bit = ~tf.uint(0x7FFFFFFF) - - def map_float_to_uint(x): - # Convert float to uint representation - ux = tf.asuint(x) - # Compute mask - mask = tf.select((ux >> 31) == 1, ~tf.uint(0), sign_bit) - # Apply XOR - return ux ^ mask - - def 
map_uint_to_float(x): - # Compute mask - mask = tf.select((x >> 31) == 0, ~tf.uint(0), sign_bit) - # Apply XOR and convert back to float - return tf.asfloat(x ^ mask) - - def map_int_to_uint(x): - return tf.asuint(x) ^ sign_bit - - def map_uint_to_int(x): - return tf.asint(x ^ sign_bit) - - tf.region_begin('Radix sort') - - has_values = values is not None - - keys = tf.copy(keys) - if has_values: - values = tf.copy(values) - - original_type = keys.type - if(original_type == tf.float32): - keys = map_float_to_uint(keys) - - if(original_type == tf.int32): - keys = map_int_to_uint(keys) - - iters = (max_bits + bits_per_pass - 1) // bits_per_pass - group_size = 128 - histogram_size = 2 ** bits_per_pass - - def GetBits(A, i): - return (A >> (i * bits_per_pass)) & tf.uint(histogram_size - 1) - - keys1 = tf.buffer(keys.shape, keys.type) - values1 = None - - if has_values: - values1 = tf.buffer(values.shape, values.type) - - with tf.loop(iters // 2) as iter: - def SortIteration(keys_in, keys_out, values_in, values_out, iter): - tf.region_begin('Radix sort iteration') - grouped = tf.split_dim(GetBits(keys_in, iter), group_size) - - # Do a packed histogram, since we sum 128 elements at a time, we can pack 4 values into a single uint32 - g, e, i = tf.indices([grouped.shape[0], grouped.shape[1], tf.int(histogram_size/4)]) - this_key = grouped[g, e] - packed_is_bit = (tf.uint(this_key == tf.uint(4*i))) + (tf.uint(this_key == tf.uint(4*i+1)) << 8) + (tf.uint(this_key == tf.uint(4*i+2)) << 16) + (tf.uint(this_key == tf.uint(4*i+3)) << 24) - packed_is_bit = tf.select((g*group_size + e) < keys_in.shape[0], packed_is_bit, tf.uint(0)) - group_histogram_packed = tf.sum(packed_is_bit, axis = 1) - - g, i = tf.indices([grouped.shape[0], histogram_size]) - group_histogram = tf.uint((group_histogram_packed[g, i / 4] >> (8*(i % 4))) & tf.uint(0xFF)) - - group_histogram_scan = prefix_sum_grouped(group_histogram, axis = 0) - i, = tf.indices([histogram_size]) - total_bit_histogram = 
tf.prefix_sum(group_histogram_scan[group_histogram_scan.shape[0] - 1, i]) - - with tf.kernel(grouped.shape, group_size=[group_size]) as (g, e): - if(tf.current_backend() == tf.cpu): #dont use group barriers on CPU - doesn't work - element = g * group_size + e - with tf.if_cond(element < keys_in.shape[0]): - old_key = keys_in[element] - old_val = values_in[element] - bit = GetBits(old_key, iter) - total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, bit]) + tf.select(bit == tf.uint(0), tf.uint(0), total_bit_histogram[bit - tf.uint(1)]) - with tf.loop(e) as j: - total_offset.val += tf.uint(grouped[g, j] == bit) - keys_out[total_offset] = old_key - values_out[total_offset] = old_val - else: - temp = tf.group_buffer(group_size, tf.uint32) - half_count = tf.group_buffer(histogram_size, tf.uint32) - gtid = g.block_thread_index(0) - - #initialize counters - for i in range((histogram_size + group_size - 1) // group_size): - index = gtid + i * group_size - with tf.if_cond(index < histogram_size): - half_count[index] = 0 - tf.group_barrier() - - element = g * group_size + e - with tf.if_cond(element < keys_in.shape[0]): - old_key = keys_in[element] - bit = GetBits(old_key, iter) - temp[gtid] = bit - - #count number of bits set in previous sub groups - quarter_index = e / (group_size // 4) - with tf.if_cond(quarter_index < 3): - tf.scatterAdd(half_count[bit], tf.uint(quarter_index < 1) | (tf.uint(quarter_index < 2) << 8) | (tf.uint(quarter_index < 3) << 16)) - - tf.group_barrier() - - if has_values: - old_val = values_in[element] - - total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, tf.int(bit)]) + tf.select(tf.int(bit) == 0, tf.uint(0), total_bit_histogram[tf.int(bit) - 1]) - total_offset += tf.select(quarter_index > 0, (half_count[bit] >> (8*(quarter_index-1))) & tf.uint(0xFF), tf.uint(0)) - begin_index = quarter_index * (group_size // 4) - with tf.loop(begin_index, e) as j: - total_offset.val += tf.uint(temp[j] == bit) - 
keys_out[total_offset] = old_key - - if has_values: - values_out[total_offset] = old_val - - tf.region_end('Radix sort iteration') - - SortIteration(keys, keys1, values, values1, 2 * iter) - SortIteration(keys1, keys, values1, values, 2 * iter + 1) - - tf.region_end('Radix sort') - - if(original_type == tf.float32): - keys = map_uint_to_float(keys) - - if(original_type == tf.int32): - keys = map_uint_to_int(keys) - - if has_values: - return keys, values - else: - return keys +# from . import TensorFrost as tf +# +# #in-place bitonic sort +# def bitonic(keys, values = None): +# tf.region_begin('Bitonic sort') +# keys = tf.copy(keys) +# if values is not None: +# values = tf.copy(values) +# element_count = keys.shape[0] +# log2_count = tf.int(tf.ceil(tf.log2(tf.float(element_count)))) +# count_round = 1 << log2_count +# idx = tf.indices([count_round / 2])[0] +# with tf.loop(log2_count) as k: +# with tf.loop(k+1) as j: +# s = 1 << (k-j) +# m_inner = s - 1 +# m_outer = ~m_inner +# m_xor = s + tf.select(j == 0, m_inner, 0) +# +# id1 = (2 * (idx & m_outer) + (idx & m_inner)) +# id2 = id1 ^ m_xor +# key1, key2 = keys[id1], keys[id2] +# with tf.if_cond((key1 >= key2) & (id1 < element_count) & (id2 < element_count)): +# if values is not None: +# val1, val2 = values[id1], values[id2] +# values[id1] = val2 +# values[id2] = val1 +# keys[id1] = key2 +# keys[id2] = key1 +# +# tf.region_end('Bitonic sort') +# if values is not None: +# return keys, values +# else: +# return keys +# +# #histogram radix sort +# def radix(keys, values = None, bits_per_pass = 6, max_bits = 32): +# def prefix_sum_grouped(A, axis = -1): +# axis = len(A.shape) + axis if axis < 0 else axis +# group_size = 64 +# grouped = tf.split_dim(A, group_size, axis) +# group_scan = tf.prefix_sum(tf.sum(grouped, axis = axis + 1), axis = axis) +# ids = grouped.indices +# gid, eid = ids[axis], ids[axis + 1] +# ids = [ids[i] for i in range(len(ids)) if i != axis + 1] +# ids[axis] = gid - 1 +# group_scan = 
tf.prefix_sum(grouped + tf.select((gid == 0) | (eid != 0), tf.uint(0), group_scan[tuple(ids)]), axis = axis + 1) +# full_scan = tf.merge_dim(group_scan, target_size = A.shape[axis], axis = axis + 1) +# return full_scan +# +# sign_bit = ~tf.uint(0x7FFFFFFF) +# +# def map_float_to_uint(x): +# # Convert float to uint representation +# ux = tf.asuint(x) +# # Compute mask +# mask = tf.select((ux >> 31) == 1, ~tf.uint(0), sign_bit) +# # Apply XOR +# return ux ^ mask +# +# def map_uint_to_float(x): +# # Compute mask +# mask = tf.select((x >> 31) == 0, ~tf.uint(0), sign_bit) +# # Apply XOR and convert back to float +# return tf.asfloat(x ^ mask) +# +# def map_int_to_uint(x): +# return tf.asuint(x) ^ sign_bit +# +# def map_uint_to_int(x): +# return tf.asint(x ^ sign_bit) +# +# tf.region_begin('Radix sort') +# +# has_values = values is not None +# +# keys = tf.copy(keys) +# if has_values: +# values = tf.copy(values) +# +# original_type = keys.type +# if(original_type == tf.float32): +# keys = map_float_to_uint(keys) +# +# if(original_type == tf.int32): +# keys = map_int_to_uint(keys) +# +# iters = (max_bits + bits_per_pass - 1) // bits_per_pass +# group_size = 128 +# histogram_size = 2 ** bits_per_pass +# +# def GetBits(A, i): +# return (A >> (i * bits_per_pass)) & tf.uint(histogram_size - 1) +# +# keys1 = tf.buffer(keys.shape, keys.type) +# values1 = None +# +# if has_values: +# values1 = tf.buffer(values.shape, values.type) +# +# with tf.loop(iters // 2) as iter: +# def SortIteration(keys_in, keys_out, values_in, values_out, iter): +# tf.region_begin('Radix sort iteration') +# grouped = tf.split_dim(GetBits(keys_in, iter), group_size) +# +# # Do a packed histogram, since we sum 128 elements at a time, we can pack 4 values into a single uint32 +# g, e, i = tf.indices([grouped.shape[0], grouped.shape[1], tf.int(histogram_size/4)]) +# this_key = grouped[g, e] +# packed_is_bit = (tf.uint(this_key == tf.uint(4*i))) + (tf.uint(this_key == tf.uint(4*i+1)) << 8) + 
(tf.uint(this_key == tf.uint(4*i+2)) << 16) + (tf.uint(this_key == tf.uint(4*i+3)) << 24) +# packed_is_bit = tf.select((g*group_size + e) < keys_in.shape[0], packed_is_bit, tf.uint(0)) +# group_histogram_packed = tf.sum(packed_is_bit, axis = 1) +# +# g, i = tf.indices([grouped.shape[0], histogram_size]) +# group_histogram = tf.uint((group_histogram_packed[g, i / 4] >> (8*(i % 4))) & tf.uint(0xFF)) +# +# group_histogram_scan = prefix_sum_grouped(group_histogram, axis = 0) +# i, = tf.indices([histogram_size]) +# total_bit_histogram = tf.prefix_sum(group_histogram_scan[group_histogram_scan.shape[0] - 1, i]) +# +# with tf.kernel(grouped.shape, group_size=[group_size]) as (g, e): +# if(tf.current_backend() == tf.cpu): #dont use group barriers on CPU - doesn't work +# element = g * group_size + e +# with tf.if_cond(element < keys_in.shape[0]): +# old_key = keys_in[element] +# old_val = values_in[element] +# bit = GetBits(old_key, iter) +# total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, bit]) + tf.select(bit == tf.uint(0), tf.uint(0), total_bit_histogram[bit - tf.uint(1)]) +# with tf.loop(e) as j: +# total_offset.val += tf.uint(grouped[g, j] == bit) +# keys_out[total_offset] = old_key +# values_out[total_offset] = old_val +# else: +# temp = tf.group_buffer(group_size, tf.uint32) +# half_count = tf.group_buffer(histogram_size, tf.uint32) +# gtid = g.block_thread_index(0) +# +# #initialize counters +# for i in range((histogram_size + group_size - 1) // group_size): +# index = gtid + i * group_size +# with tf.if_cond(index < histogram_size): +# half_count[index] = 0 +# tf.group_barrier() +# +# element = g * group_size + e +# with tf.if_cond(element < keys_in.shape[0]): +# old_key = keys_in[element] +# bit = GetBits(old_key, iter) +# temp[gtid] = bit +# +# #count number of bits set in previous sub groups +# quarter_index = e / (group_size // 4) +# with tf.if_cond(quarter_index < 3): +# tf.scatterAdd(half_count[bit], tf.uint(quarter_index < 1) | 
(tf.uint(quarter_index < 2) << 8) | (tf.uint(quarter_index < 3) << 16)) +# +# tf.group_barrier() +# +# if has_values: +# old_val = values_in[element] +# +# total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, tf.int(bit)]) + tf.select(tf.int(bit) == 0, tf.uint(0), total_bit_histogram[tf.int(bit) - 1]) +# total_offset += tf.select(quarter_index > 0, (half_count[bit] >> (8*(quarter_index-1))) & tf.uint(0xFF), tf.uint(0)) +# begin_index = quarter_index * (group_size // 4) +# with tf.loop(begin_index, e) as j: +# total_offset.val += tf.uint(temp[j] == bit) +# keys_out[total_offset] = old_key +# +# if has_values: +# values_out[total_offset] = old_val +# +# tf.region_end('Radix sort iteration') +# +# SortIteration(keys, keys1, values, values1, 2 * iter) +# SortIteration(keys1, keys, values1, values, 2 * iter + 1) +# +# tf.region_end('Radix sort') +# +# if(original_type == tf.float32): +# keys = map_uint_to_float(keys) +# +# if(original_type == tf.int32): +# keys = map_uint_to_int(keys) +# +# if has_values: +# return keys, values +# else: +# return keys diff --git a/TensorFrost/include/Compiler/Operation.h b/TensorFrost/include/Compiler/Operation.h index b006ec3e..9f8b823d 100644 --- a/TensorFrost/include/Compiler/Operation.h +++ b/TensorFrost/include/Compiler/Operation.h @@ -17,10 +17,7 @@ struct Op { std::vector> blocks; Op(std::string op_name); - Op(int value); - Op(uint value); - Op(float value); - Op(bool value); + OpBlock* NewBlock(); }; diff --git a/TensorFrost/include/Compiler/OperationBlocks.h b/TensorFrost/include/Compiler/OperationBlocks.h index 7ad8705f..a09b8efc 100644 --- a/TensorFrost/include/Compiler/OperationBlocks.h +++ b/TensorFrost/include/Compiler/OperationBlocks.h @@ -4,8 +4,11 @@ namespace TensorFrost { struct OpBlock { + Op* parent_op = nullptr; std::list> ops; Op* append(std::unique_ptr op); + + OpBlock(Op* parent = nullptr); }; class OpBlockIterator { diff --git a/TensorFrost/src/Compiler/ExecutionContext.cpp 
b/TensorFrost/src/Compiler/ExecutionContext.cpp index 1e1c6a9f..dd49b740 100644 --- a/TensorFrost/src/Compiler/ExecutionContext.cpp +++ b/TensorFrost/src/Compiler/ExecutionContext.cpp @@ -7,7 +7,7 @@ ExecutionContext::ExecutionContext(): base_block(std::make_unique()), c void ExecutionContext::BeginBlock(Op *op) { stack.push_back(current_block); - current_block = new OpBlock(); + current_block = op->NewBlock(); } void ExecutionContext::EndBlock() { diff --git a/TensorFrost/src/Compiler/Operation.cpp b/TensorFrost/src/Compiler/Operation.cpp index 96e96914..bcd0a31f 100644 --- a/TensorFrost/src/Compiler/Operation.cpp +++ b/TensorFrost/src/Compiler/Operation.cpp @@ -6,23 +6,8 @@ Op::Op(std::string op_name): opcode(std::move(op_name)) { type = TFTypeNone; } -Op::Op(int value) : Op("const") { - attributes["value"] = value; - type = TFTypeInt32; +OpBlock* Op::NewBlock() { + blocks.emplace_back(std::make_unique(this)); + return blocks.back().get(); } - -Op::Op(uint value) : Op("const") { - attributes["value"] = value; - type = TFTypeUint32; -} - -Op::Op(float value) : Op("const") { - attributes["value"] = value; - type = TFTypeFloat32; -} - -Op::Op(bool value) : Op(std::string("const")) { - attributes["value"] = value; - type = TFTypeBool32; } -} \ No newline at end of file diff --git a/TensorFrost/src/Compiler/OperationBlocks.cpp b/TensorFrost/src/Compiler/OperationBlocks.cpp index 6af9901f..d53072d0 100644 --- a/TensorFrost/src/Compiler/OperationBlocks.cpp +++ b/TensorFrost/src/Compiler/OperationBlocks.cpp @@ -6,6 +6,8 @@ Op* OpBlock::append(std::unique_ptr op) { return ops.back().get(); } +OpBlock::OpBlock(Op *parent): parent_op(parent) {} + OpBlockIterator::OpBlockIterator(OpBlock* root) : current_op(nullptr) { if (root && !root->ops.empty()) { stack.push_back({root, root->ops.begin(), root->ops.end()}); diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index a33ca7ae..b04f4ae6 100644 --- 
a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -56,6 +56,8 @@ OverloadsMap ovr(const std::string& input) { } vector default_operations = { + OpSpec("const", ovr("f(); u(); i(); b(); tuple()")), + OpSpec("add", ovr("f(f,f); u(u,u); i(i,i)")), OpSpec("sub", ovr("f(f,f); u(u,u); i(i,i)")), OpSpec("mul", ovr("f(f,f); u(u,u); i(i,i)")), diff --git a/examples/debug.py b/examples/debug.py index 8037c7bb..0c24adae 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -1,46 +1,2 @@ import numpy as np import TensorFrost as tf - -tf.initialize(tf.cpu) - -blur_d = 16 -blur_r = blur_d * 0.5 - -def kernel(r): - #return 1.0 - return tf.exp(-0.5 * (r / blur_r)**2.0) / (blur_r * np.sqrt(2.0 * np.pi)) - -# def blurfunc(): -# img = tf.func_input([-1, -1], tf.float32) -# with tf.loop(-blur_d, blur_d+1, 1) as k: -# blur_h += img[i+k, j] * kernel(tf.float(k)) -# with tf.loop(-blur_d, blur_d+1, 1) as k: -# blur_v += blur_h[i, j+k] * kernel(tf.float(k)) -# return blur_v -# -# def blur(): -# img = tf.input([-1, -1, -1], tf.float32) -# N, M, C = img.shape -# blur_h = tf.zeros(img.shape, tf.float32) -# blur_v = tf.zeros(img.shape, tf.float32) -# i, j, ch = img.indices -# -# tf.vmap(inputs=[img], map=[C], func=blurfunc); -# -# return blur_v - -@tf.compile -def blur(img: tf.Arg([-1, -1, -1], tf.float32)): - blur_h = tf.zeros(img.shape, tf.float32) - blur_v = tf.zeros(img.shape, tf.float32) - i, j, ch = img.indices - - #horizontal blur - with tf.loop(-blur_d, blur_d+1, 1) as k: - blur_h += img[i+k, j, ch] * kernel(tf.float(k)) - - #vertical blur - with tf.loop(-blur_d, blur_d+1, 1) as k: - blur_v += blur_h[i, j+k, ch] * kernel(tf.float(k)) - - return blur_v From 71e42d30dd0a96eda202db4d55b7dff2d0f6e13a Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Sun, 1 Jun 2025 06:05:58 +0200 Subject: [PATCH 06/44] Basic IR printing --- TensorFrost/PybindModule.cpp | 6 ++- .../include/Compiler/OperationArguments.h | 2 + 
TensorFrost/include/Compiler/Printer.h | 9 ++++ TensorFrost/include/TensorFrost.h | 3 +- .../src/Compiler/OperationArguments.cpp | 6 ++- TensorFrost/src/Compiler/Printer.cpp | 42 +++++++++++++++++++ 6 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 TensorFrost/include/Compiler/Printer.h create mode 100644 TensorFrost/src/Compiler/Printer.cpp diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 258448bd..4f9d4831 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -136,8 +136,10 @@ PYBIND11_MODULE(TensorFrost, m) { StartExecutionContext(); Op& a = constant(5); Op& b = constant(10); - Op& c = a + b; - py::print("Created operation: ", c.opcode); + Op& c = a + b * 3; + std::string tree = PrintTree(*GetContext()->base_block.get()); + py::print("Created operation tree:"); + py::print(tree); } } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/include/Compiler/OperationArguments.h b/TensorFrost/include/Compiler/OperationArguments.h index bcc42f00..d411a875 100644 --- a/TensorFrost/include/Compiler/OperationArguments.h +++ b/TensorFrost/include/Compiler/OperationArguments.h @@ -37,6 +37,8 @@ struct ArgumentManager { void AddArgument(Op* from, ArgType type, int index = 0); void SetAsOutput(Argument *arg); void SetArguments(ArgType type, std::vector args); + + Arguments* GetArguments(ArgType type) const; }; } \ No newline at end of file diff --git a/TensorFrost/include/Compiler/Printer.h b/TensorFrost/include/Compiler/Printer.h new file mode 100644 index 00000000..4d24dcb2 --- /dev/null +++ b/TensorFrost/include/Compiler/Printer.h @@ -0,0 +1,9 @@ +#pragma once +#include "Operation.h" +#include "OperationBlocks.h" + +namespace TensorFrost { + +void PrintOp(const Op& op, std::ostream& os); +std::string PrintTree(const OpBlock& base_block); +} diff --git a/TensorFrost/include/TensorFrost.h b/TensorFrost/include/TensorFrost.h index 140eb362..4fa8ac4d 100644 --- 
a/TensorFrost/include/TensorFrost.h +++ b/TensorFrost/include/TensorFrost.h @@ -4,4 +4,5 @@ #include "Compiler/ExecutionContext.h" #include "Compiler/OperationBlocks.h" #include "Compiler/OperationArguments.h" -#include "Compiler/Overloads.h" \ No newline at end of file +#include "Compiler/Overloads.h" +#include "Compiler/Printer.h" \ No newline at end of file diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp index 1edc31c6..bf3782d8 100644 --- a/TensorFrost/src/Compiler/OperationArguments.cpp +++ b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -40,4 +40,8 @@ void ArgumentManager::SetArguments(ArgType type, std::vector args) { AddArgument(args[i], type, (int)i); } } -} \ No newline at end of file + +Arguments * ArgumentManager::GetArguments(ArgType type) const { + return type_args[(int)type].get(); +} +} diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp new file mode 100644 index 00000000..fc4a97a8 --- /dev/null +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -0,0 +1,42 @@ +#include "Compiler/Operation.h" +#include "Compiler/ExecutionContext.h" +#include "Compiler/Printer.h" + +using namespace std; + +namespace TensorFrost { +void PrintOp(const Op &op, std::ostringstream &os) { + os << "Op: " << op.opcode << "\n"; + os << "Type: " << ToString(op.type) << "\n"; + os << "Arguments:\n"; + for (int i = 0; i < (int)ArgType::Count; ++i) { + const auto args = op.args->GetArguments(static_cast(i)); + if (args) { + os << " " << ToString(static_cast(i)) << ":\n"; + for (const auto& arg : args->inputs) { + if (arg) { + os << " From: " << (arg->from ? 
arg->from->opcode : "nullptr") + << ", Index: " << arg->index << "\n"; + } + } + } + } + os << "Attributes:\n"; + for (const auto& [key, value] : op.attributes) { + os << " " << key << ": "; + std::visit([&os](const auto& v) { os << v; }, value); + os << "\n"; + } +} + +std::string PrintTree(const OpBlock &base_block) { + OpBlockIterator it(const_cast(&base_block)); + std::ostringstream oss; + while (Op* op = it.next()) { + PrintOp(*op, oss); + oss << "\n"; + } + return oss.str(); +} + +} From 0c1885ba8ff0866b2744801eb7bb50d71c0aad81 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Sun, 1 Jun 2025 22:15:53 +0200 Subject: [PATCH 07/44] Various improvements --- ProtoIR.txt | 83 +------------- TensorFrost/PybindModule.cpp | 6 +- TensorFrost/include/Compiler/Common.h | 27 +++-- .../include/Compiler/ExecutionContext.h | 19 ++- TensorFrost/include/Compiler/Operation.h | 4 + .../include/Compiler/OperationArguments.h | 12 +- .../include/Compiler/OperationBlocks.h | 51 +++++---- .../include/Compiler/OperationRegistry.h | 3 +- TensorFrost/include/Compiler/Overloads.h | 27 ++--- TensorFrost/include/Compiler/Printer.h | 4 +- TensorFrost/src/Compiler/ExecutionContext.cpp | 72 +++++++++--- .../src/Compiler/OperationArguments.cpp | 43 ++++++- TensorFrost/src/Compiler/OperationBlocks.cpp | 108 ++++++------------ .../src/Compiler/OperationRegistry.cpp | 5 +- TensorFrost/src/Compiler/Overloads.cpp | 27 ++++- TensorFrost/src/Compiler/Printer.cpp | 101 ++++++++++++---- 16 files changed, 339 insertions(+), 253 deletions(-) diff --git a/ProtoIR.txt b/ProtoIR.txt index b48dbdf4..ac6519dc 100644 --- a/ProtoIR.txt +++ b/ProtoIR.txt @@ -13,85 +13,4 @@ ids = parallel(args{shape=[n, n]}, attributes{type=tuple}) { a0 = load(args{input=[b], indices=[i]}, attributes{type=float32}) a1 = load(args{input=[b], indices=[j]}, attributes{type=float32}) outer = mul(args{input=[a0, a1]}, attributes{type=float32, output_index=0}) -} - - -vmap (shape = [a, b, c]) { - A_2 = load(memory=[A], 
indices=[i,j]) - int(32) v2_0(32) = const(outputs=[v2_1, v2_2, ], data=[1], index=15, debug_index=32, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - float(32) v2_1(32) = dim_norm(inputs=[A_2], outputs=[v2_2, ], data=[0], shape=[v2_0(1)], index=16, debug_index=34, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - store(memory=[R(0.f)], inputs=[v2_1], indices=[i,i], shape=[v2_0(1)], index=17, debug_index=36, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - -} - -int(32) v1_0(32) = const(data=[4294967295], index=1, debug_index=4, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v1_1(32) = const(data=[4294967295], index=2, debug_index=6, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) n(32) = input_shape(outputs=[v1_3, v2_8, v3_5, v3_1, v3_12, v3_7, v3_10, v3_17, v3_14, A, Q(0.f), R(0.f), R(0.f), ], flags={InputShapeDim(1), }, index=3, debug_index=8, debug_name=n, created_in=Tracing initial graph, created_in_func=None, ) -int(32) m(32) = input_shape(outputs=[v2_9, A_2, j, A_3, v2_4, v2_3, A_6, A_7, v3_15, v3_18, A, Q(0.f), ], flags={InputShapeDim(0), }, index=4, debug_index=10, debug_name=m, created_in=Tracing initial graph, created_in_func=None, ) -float(32) A(32) = memory(outputs=[A_2, A_3, A_5, v2_20, A_4, A_6, A_7, ], flags={Modified, InputMemory(0), }, shape=[n,m], index=5, debug_index=12, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) -float(32) Q(32) = const(outputs=[Q_2, v2_4, Q_3, v3_18, ], data=[0], flags={Modified, OutputMemory(0), }, shape=[n,m], index=6, debug_index=14, debug_name=Q, created_in=Tracing initial graph, created_in_func=None, ) -float(32) R(32) = const(outputs=[v2_2, R_2, v2_17, R_3, v3_8, R_4, ], data=[0], flags={Modified, OutputMemory(1), }, shape=[n,n], index=7, debug_index=16, debug_name=R, created_in=Tracing initial graph, created_in_func=None, ) -int(32) j(32) = 
dim_id(outputs=[A_2, A_3, v2_4, A_6, A_7, v3_18, ], data=[0], shape=[m], index=8, debug_index=18, debug_name=j, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v1_2(32) = const(outputs=[v1_3, ], data=[1], index=9, debug_index=20, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v1_3(32) = sub(inputs=[n,v1_2(1)], outputs=[i, ], index=10, debug_index=22, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v1_4(32) = const(outputs=[i, ], data=[1], index=11, debug_index=24, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v1_5(32) = const(outputs=[i, ], data=[0], index=12, debug_index=26, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) i(32) = loop(inputs=[v1_5(0),v1_3,v1_4(1)], outputs=[v2_6, v2_14, A_2, v2_2, Q_2, R_2, A_3, v2_4, Q_3, v2_2, R_2, v2_17, R_3, ], index=13, debug_index=28, debug_name=i, created_in=Tracing initial graph, created_in_func=None, ) -{ - float(32) A_2(32) = load(memory=[A], indices=[i,j], outputs=[v2_1, ], data=[0], shape=[m], index=14, debug_index=29, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_0(32) = const(outputs=[v2_1, v2_2, ], data=[1], index=15, debug_index=32, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - float(32) v2_1(32) = dim_norm(inputs=[A_2], outputs=[v2_2, ], data=[0], shape=[v2_0(1)], index=16, debug_index=34, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - store(memory=[R(0.f)], inputs=[v2_1], indices=[i,i], shape=[v2_0(1)], index=17, debug_index=36, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - float(32) A_3(32) = load(memory=[A], indices=[i,j], outputs=[v2_3, ], data=[0], shape=[m], index=18, debug_index=38, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) - float(32) R_2(32) = load(memory=[R(0.f)], indices=[i,i], outputs=[v2_3, ], data=[0], 
index=19, debug_index=40, debug_name=R, created_in=Tracing initial graph, created_in_func=None, ) - float(32) v2_3(32) = div(inputs=[A_3,R_2], outputs=[v2_4, ], shape=[m], index=20, debug_index=42, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - store(memory=[Q(0.f)], inputs=[v2_3], indices=[i,j], shape=[m], index=21, debug_index=44, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_5(32) = const(outputs=[v2_6, ], data=[1], index=22, debug_index=46, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_6(32) = add(inputs=[i,v2_5(1)], outputs=[v2_8, k, ], index=23, debug_index=48, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_7(32) = const(outputs=[p, v2_9, ], data=[0], index=24, debug_index=50, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_8(32) = sub(inputs=[n,v2_6], outputs=[p, Q_2, v2_10, v2_11, k, v2_17, Q_3, v2_13, v2_16, A_5, v2_12, v2_18, v2_14, v2_19, v2_20, R_3, A_4, dot, ], index=25, debug_index=52, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_9(32) = sub(inputs=[m,v2_7(0)], outputs=[p, Q_2, v2_10, v2_11, k, Q_3, A_5, v2_12, v2_18, v2_19, v2_20, R_3, A_4, ], index=26, debug_index=54, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_10(32) = dim_id(outputs=[k, ], data=[0], shape=[v2_8,v2_9], index=27, debug_index=56, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) k(32) = add(inputs=[v2_10,v2_6], outputs=[A_5, v2_20, R_3, A_4, ], shape=[v2_8,v2_9], index=28, debug_index=58, debug_name=k, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_11(32) = dim_id(outputs=[p, ], data=[1], shape=[v2_8,v2_9], index=29, debug_index=60, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) p(32) = add(inputs=[v2_11,v2_7(0)], outputs=[Q_2, Q_3, A_5, v2_20, 
A_4, ], shape=[v2_8,v2_9], index=30, debug_index=62, debug_name=p, created_in=Tracing initial graph, created_in_func=None, ) - float(32) Q_2(32) = load(memory=[Q(0.f)], indices=[i,p], outputs=[v2_12, ], data=[0], shape=[v2_8,v2_9], index=31, debug_index=64, debug_name=Q, created_in=Tracing initial graph, created_in_func=None, ) - float(32) A_4(32) = load(memory=[A], indices=[k,p], outputs=[v2_12, ], data=[0], shape=[v2_8,v2_9], index=32, debug_index=66, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) - float(32) v2_12(32) = mul(inputs=[Q_2,A_4], outputs=[dot, ], shape=[v2_8,v2_9], index=33, debug_index=68, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - float(32) dot(32) = dim_sum(inputs=[v2_12], outputs=[v2_17, ], data=[1], shape=[v2_8], index=34, debug_index=70, debug_name=dot, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_13(32) = dim_id(outputs=[v2_14, ], data=[0], shape=[v2_8], index=35, debug_index=72, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_14(32) = add(inputs=[v2_13,i], outputs=[v2_16, ], shape=[v2_8], index=36, debug_index=74, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_15(32) = const(outputs=[v2_16, ], data=[1], index=37, debug_index=76, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - int(32) v2_16(32) = add(inputs=[v2_14,v2_15(1)], outputs=[v2_17, ], shape=[v2_8], index=38, debug_index=78, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - store(memory=[R(0.f)], inputs=[dot], indices=[v2_16,i], shape=[v2_8], index=39, debug_index=80, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - float(32) A_5(32) = load(memory=[A], indices=[k,p], outputs=[v2_19, ], data=[0], shape=[v2_8,v2_9], index=40, debug_index=82, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) - float(32) Q_3(32) = load(memory=[Q(0.f)], 
indices=[i,p], outputs=[v2_18, ], data=[0], shape=[v2_8,v2_9], index=41, debug_index=84, debug_name=Q, created_in=Tracing initial graph, created_in_func=None, ) - float(32) R_3(32) = load(memory=[R(0.f)], indices=[k,i], outputs=[v2_18, ], data=[0], shape=[v2_8,v2_9], index=42, debug_index=86, debug_name=R, created_in=Tracing initial graph, created_in_func=None, ) - float(32) v2_18(32) = mul(inputs=[Q_3,R_3], outputs=[v2_19, ], shape=[v2_8,v2_9], index=43, debug_index=88, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - float(32) v2_19(32) = sub(inputs=[A_5,v2_18], outputs=[v2_20, ], shape=[v2_8,v2_9], index=44, debug_index=90, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) - store(memory=[A], inputs=[v2_19], indices=[k,p], shape=[v2_8,v2_9], index=45, debug_index=92, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -} -int(32) v3_0(32) = const(outputs=[v3_1, ], data=[1], index=46, debug_index=30, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_1(32) = sub(inputs=[n,v3_0(1)], outputs=[A_6, ], index=47, debug_index=96, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -float(32) A_6(32) = load(memory=[A], indices=[v3_1,j], outputs=[v3_3, ], data=[0], shape=[m], index=48, debug_index=98, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_2(32) = const(outputs=[v3_8, v3_3, ], data=[1], index=49, debug_index=100, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -float(32) v3_3(32) = dim_norm(inputs=[A_6], outputs=[v3_8, ], data=[0], shape=[v3_2(1)], index=50, debug_index=102, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_4(32) = const(outputs=[v3_5, ], data=[1], index=51, debug_index=104, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_5(32) = sub(inputs=[n,v3_4(1)], outputs=[v3_8, ], index=52, debug_index=106, 
debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_6(32) = const(outputs=[v3_7, ], data=[1], index=53, debug_index=108, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_7(32) = sub(inputs=[n,v3_6(1)], outputs=[v3_8, ], index=54, debug_index=110, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -store(memory=[R(0.f)], inputs=[v3_3], indices=[v3_7,v3_5], shape=[v3_2(1)], index=55, debug_index=112, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_9(32) = const(outputs=[v3_10, ], data=[1], index=56, debug_index=114, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_10(32) = sub(inputs=[n,v3_9(1)], outputs=[A_7, ], index=57, debug_index=116, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -float(32) A_7(32) = load(memory=[A], indices=[v3_10,j], outputs=[v3_15, ], data=[0], shape=[m], index=58, debug_index=118, debug_name=A, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_11(32) = const(outputs=[v3_12, ], data=[1], index=59, debug_index=120, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_12(32) = sub(inputs=[n,v3_11(1)], outputs=[R_4, ], index=60, debug_index=122, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_13(32) = const(outputs=[v3_14, ], data=[1], index=61, debug_index=124, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_14(32) = sub(inputs=[n,v3_13(1)], outputs=[R_4, ], index=62, debug_index=126, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -float(32) R_4(32) = load(memory=[R(0.f)], indices=[v3_14,v3_12], outputs=[v3_15, ], data=[0], index=63, debug_index=128, debug_name=R, created_in=Tracing initial graph, created_in_func=None, ) -float(32) v3_15(32) = div(inputs=[A_7,R_4], outputs=[v3_18, ], shape=[m], index=64, debug_index=130, 
debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_16(32) = const(outputs=[v3_17, ], data=[1], index=65, debug_index=132, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -int(32) v3_17(32) = sub(inputs=[n,v3_16(1)], outputs=[v3_18, ], index=66, debug_index=134, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -store(memory=[Q(0.f)], inputs=[v3_15], indices=[v3_17,j], shape=[m], index=67, debug_index=136, debug_name=, created_in=Tracing initial graph, created_in_func=None, ) -region_end(index=68, debug_index=138, debug_name=blur, created_in=Tracing initial graph, created_in_func=None, ) - +} \ No newline at end of file diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 4f9d4831..b5133c99 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -137,7 +137,11 @@ PYBIND11_MODULE(TensorFrost, m) { Op& a = constant(5); Op& b = constant(10); Op& c = a + b * 3; - std::string tree = PrintTree(*GetContext()->base_block.get()); + vmap({&a, &b, &c}, [&](Op* op) { + Op& d = c + b; + }); + AssignVariableNames(*GetBaseBlock()); + std::string tree = PrintBlock(*GetBaseBlock()); py::print("Created operation tree:"); py::print(tree); } diff --git a/TensorFrost/include/Compiler/Common.h b/TensorFrost/include/Compiler/Common.h index 9b0d53d8..69e37d64 100644 --- a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -9,6 +9,8 @@ #include #include #include +#include +#include namespace TensorFrost { extern "C" { @@ -83,13 +85,13 @@ inline std::string ToString(ArgType type) { inline std::string ToString(const TFDataFormat& format) { switch (format.type) { - case TFType::Float: return "Float" + std::to_string(format.size); - case TFType::Uint: return "Uint" + std::to_string(format.size); - case TFType::Int: return "Int" + std::to_string(format.size); - case TFType::Bool: return "Bool" + std::to_string(format.size); - case 
TFType::Tuple: return "Tuple"; - case TFType::None: return "None"; - default: return "Unknown"; + case TFType::Float: return "float" + std::to_string(format.size); + case TFType::Uint: return "uint" + std::to_string(format.size); + case TFType::Int: return "int" + std::to_string(format.size); + case TFType::Bool: return "bool" + std::to_string(format.size); + case TFType::Tuple: return "tuple"; + case TFType::None: return "none"; + default: return "unknown"; } } @@ -106,8 +108,19 @@ struct Argument; using Attribute = std::variant; using AttributeMap = std::unordered_map; +//ostringstream conversion for Attribute +inline std::ostream& operator<<(std::ostream& os, const Attribute& attr) { + std::visit([&os](const auto& v) { os << v; }, attr); + return os; +} + +inline std::string ToString(const Attribute& attr) { + std::ostringstream oss; + oss << attr; + return oss.str(); } +} namespace std { template<> struct hash { diff --git a/TensorFrost/include/Compiler/ExecutionContext.h b/TensorFrost/include/Compiler/ExecutionContext.h index 7d8a5292..55805a9d 100644 --- a/TensorFrost/include/Compiler/ExecutionContext.h +++ b/TensorFrost/include/Compiler/ExecutionContext.h @@ -1,24 +1,31 @@ #pragma once #include "Common.h" +#include "Operation.h" namespace TensorFrost { struct ExecutionContext { std::unique_ptr base_block; - OpBlock* current_block; - std::vector stack; + OpBlock::Iterator cursor; + std::stack stack; ExecutionContext(); + void BeginCursor(OpBlock::Iterator it); + void EndCursor(); - void BeginBlock(Op* op); - void EndBlock(); - - Op &AddOp(std::unique_ptr op); + Op &Add(std::unique_ptr op); + Op &AddBeforeCursor(std::unique_ptr op); }; void StartExecutionContext(); ExecutionContext* GetContext(); +OpBlock* GetBaseBlock(); +OpBlock* GetCurrentBlock(); +void BeginCursor(OpBlock::Iterator it); +void BeginCursor(OpBlock& block); +void BeginCursor(Op* op); +void EndCursor(); void EndExecutionContext(); } diff --git a/TensorFrost/include/Compiler/Operation.h 
b/TensorFrost/include/Compiler/Operation.h index 9f8b823d..5e2eaf36 100644 --- a/TensorFrost/include/Compiler/Operation.h +++ b/TensorFrost/include/Compiler/Operation.h @@ -10,7 +10,11 @@ namespace TensorFrost { struct Op { + OpBlock* parent_block = nullptr; + + size_t index = 0; //might not be up to date std::string opcode; + std::string varname; std::unique_ptr args; AttributeMap attributes; TFDataFormat type; diff --git a/TensorFrost/include/Compiler/OperationArguments.h b/TensorFrost/include/Compiler/OperationArguments.h index d411a875..01e304d4 100644 --- a/TensorFrost/include/Compiler/OperationArguments.h +++ b/TensorFrost/include/Compiler/OperationArguments.h @@ -17,6 +17,10 @@ struct Arguments { void AddInput(ArgType type, Op* from, int index = 0); bool CheckValidity(bool throw_error = false) const; + void RemoveInput(int index); + + std::vector Args() const; + std::vector Inputs() const; }; struct ShapeArgs : Arguments { @@ -24,9 +28,7 @@ struct ShapeArgs : Arguments { float GetSizeEstimate(); void ExpandDimensionsTo(int new_dim); - bool CompareShape(const ShapeArgs& other, bool throw_error = false) const { - //TODO: Implement shape comparison logic - } + bool CompareShape(const ShapeArgs& other, bool throw_error = false) const; }; struct ArgumentManager { @@ -36,9 +38,11 @@ struct ArgumentManager { ArgumentManager(Op* parent); void AddArgument(Op* from, ArgType type, int index = 0); void SetAsOutput(Argument *arg); + void RemoveOutput(Argument *arg); void SetArguments(ArgType type, std::vector args); - Arguments* GetArguments(ArgType type) const; + Arguments* Get(ArgType type) const; + Arguments* operator[](ArgType type) const; }; } \ No newline at end of file diff --git a/TensorFrost/include/Compiler/OperationBlocks.h b/TensorFrost/include/Compiler/OperationBlocks.h index a09b8efc..ba76cf87 100644 --- a/TensorFrost/include/Compiler/OperationBlocks.h +++ b/TensorFrost/include/Compiler/OperationBlocks.h @@ -2,38 +2,43 @@ #include "Operation.h" namespace 
TensorFrost { - struct OpBlock { + using List = std::list>; + using It = List::iterator; + Op* parent_op = nullptr; - std::list> ops; - Op* append(std::unique_ptr op); + List ops; OpBlock(Op* parent = nullptr); -}; -class OpBlockIterator { -public: - using OpIter = std::list>::iterator; - using OpRevIter = std::list>::reverse_iterator; + class Iterator { + OpBlock* parent_; + List* list_; + It cur_; - struct Frame { - OpBlock* block; - OpIter it; - OpIter end; - }; + public: + Iterator(OpBlock *parent, List::iterator it); + + Op* operator*() const; + Op* operator->() const; + Op* get_next() const; + Op* get_prev() const; - OpBlockIterator(OpBlock* root); + Iterator& next(); + Iterator& prev(); + Iterator& insert_after(std::unique_ptr op); + Iterator& insert_before(std::unique_ptr op); - Op* next(); // Move to next Op in depth-first order - Op* prev(); // Move to previous Op in depth-first order - bool down(); // Enter the first sub-block of current Op (if any) - bool up(); // Exit to parent block + OpBlock* parent() const { return parent_; } - Op* current() const; + bool valid() const; + bool operator==(const Iterator& o) const; + bool operator!=(const Iterator& o) const; + }; -private: - std::vector stack; - Op* current_op; + Iterator begin(); + Iterator end(); }; -} \ No newline at end of file +void ApplyOpTransform(OpBlock& block, const std::function& transform); +} diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h index 6b72260e..1cf138c5 100644 --- a/TensorFrost/include/Compiler/OperationRegistry.h +++ b/TensorFrost/include/Compiler/OperationRegistry.h @@ -8,8 +8,9 @@ using OverloadsMap = std::unordered_map, TFDataFormat, struct OpSpec { std::string name; OverloadsMap overloads; + int blocks = 0; - OpSpec(std::string op_name, OverloadsMap overloads_list); + OpSpec(std::string op_name, OverloadsMap overloads_list, int block_count = 0); TFDataFormat GetOutputType(const std::vector& args) const; }; 
diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h index bbc84d7d..ffdf312e 100644 --- a/TensorFrost/include/Compiler/Overloads.h +++ b/TensorFrost/include/Compiler/Overloads.h @@ -2,16 +2,8 @@ #include "Operation.h" namespace TensorFrost { - Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args, std::vector shape); - -template -Op &func_op(const std::string &name, const Ts &... args) { - std::vector v; - v.reserve(sizeof...(args)); - (v.push_back(const_cast(&args)), ...); - return make_op(name, {}, {}, v, {}); -} +Op& func_op(const std::string& name, std::vector args = {}, std::vector shape = {}); Op& constant(int value); Op& constant(uint value); @@ -22,50 +14,50 @@ template concept Num = std::is_arithmetic_v>; template concept IsOp = std::same_as, Op>; template -inline Op& as_op(T v) { +inline Op* as_op(T v) { using D = std::remove_cvref_t; using Target = std::conditional_t, bool, std::conditional_t, float, std::conditional_t, unsigned int, int>>>; - return constant(static_cast(v)); + return &constant(static_cast(v)); } -inline Op& as_op(const Op& x) { return const_cast(x); } +inline Op* as_op(const Op& x) { return &const_cast(x); } #define UNARY_OPERATOR(op, opname) \ template \ requires IsOp \ inline Op& operator op(const T& x) { \ - return func_op(opname, as_op(x)); \ + return func_op(opname, {as_op(x)}); \ } #define BINARY_OPERATOR(op, opname) \ template \ requires (IsOp || IsOp) \ inline Op& operator op(const T& x, const U& y) { \ - return func_op(opname, as_op(x), as_op(y)); \ + return func_op(opname, {as_op(x), as_op(y)}); \ } #define UNARY_FUNCTION(func, opname) \ template \ requires IsOp \ inline Op& func(const T& x) { \ - return func_op(opname, as_op(x)); \ + return func_op(opname, {as_op(x)}); \ } #define BINARY_FUNCTION(func, opname) \ template \ requires (IsOp || IsOp) \ inline Op& func(const T& x, const U& y) { \ - return func_op(opname, as_op(x), as_op(y)); \ + return 
func_op(opname, {as_op(x), as_op(y)}); \ } #define TERNARY_FUNCTION(func, opname) \ template \ requires (IsOp || IsOp || IsOp) \ inline Op& func(const T& x, const U& y, const V& z) { \ - return func_op(opname, as_op(x), as_op(y), as_op(z)); \ + return func_op(opname, {as_op(x), as_op(y), as_op(z)}); \ } UNARY_OPERATOR(+, "pos") @@ -145,4 +137,5 @@ TERNARY_FUNCTION(smoothstep, "smoothstep") TERNARY_FUNCTION(select, "ternary") TERNARY_FUNCTION(fma, "fma") +Op& vmap(std::vector shape, std::function body); } diff --git a/TensorFrost/include/Compiler/Printer.h b/TensorFrost/include/Compiler/Printer.h index 4d24dcb2..79ca40df 100644 --- a/TensorFrost/include/Compiler/Printer.h +++ b/TensorFrost/include/Compiler/Printer.h @@ -5,5 +5,7 @@ namespace TensorFrost { void PrintOp(const Op& op, std::ostream& os); -std::string PrintTree(const OpBlock& base_block); +std::string PrintBlock(OpBlock& base_block); +void AssignVariableNames(OpBlock &block); + } diff --git a/TensorFrost/src/Compiler/ExecutionContext.cpp b/TensorFrost/src/Compiler/ExecutionContext.cpp index dd49b740..825bb9b9 100644 --- a/TensorFrost/src/Compiler/ExecutionContext.cpp +++ b/TensorFrost/src/Compiler/ExecutionContext.cpp @@ -3,25 +3,31 @@ #include "Compiler/OperationBlocks.h" namespace TensorFrost { -ExecutionContext::ExecutionContext(): base_block(std::make_unique()), current_block(base_block.get()) {} +ExecutionContext::ExecutionContext(): base_block(std::make_unique()), cursor(base_block->begin()) {} -void ExecutionContext::BeginBlock(Op *op) { - stack.push_back(current_block); - current_block = op->NewBlock(); +void ExecutionContext::BeginCursor(OpBlock::Iterator it) { + stack.push(&cursor); + cursor = it; } -void ExecutionContext::EndBlock() { - if (!stack.empty()) { - current_block = stack.back(); - stack.pop_back(); - } else { - throw std::runtime_error("No block to end"); +void ExecutionContext::EndCursor() { + if (stack.empty()) { + throw std::runtime_error("This is the last cursor, cannot end 
it"); } + cursor = *stack.top(); + stack.pop(); } -Op &ExecutionContext::AddOp(std::unique_ptr op) { - current_block->append(std::move(op)); - return *current_block->ops.back(); +Op& ExecutionContext::Add(std::unique_ptr op) { + cursor.insert_before(std::move(op)); + Op* new_op = cursor.get_next(); + cursor.next(); // Move cursor to the new op + return *new_op; +} + +Op& ExecutionContext::AddBeforeCursor(std::unique_ptr op) { + cursor.insert_before(std::move(op)); + return **cursor; } ExecutionContext* current_context = nullptr; @@ -37,6 +43,46 @@ ExecutionContext* GetContext() { return current_context; } +OpBlock* GetBaseBlock() { + if (!current_context) { + throw std::runtime_error("No execution context available"); + } + return current_context->base_block.get(); +} + +OpBlock* GetCurrentBlock() { + if (!current_context) { + throw std::runtime_error("No execution context available"); + } + return current_context->cursor.parent(); +} + +void BeginCursor(OpBlock::Iterator it) { + GetContext()->BeginCursor(it); +} + +void BeginCursor(OpBlock& block) { + GetContext()->BeginCursor(block.begin()); +} + +void BeginCursor(Op* op) { + if (!op || !op->parent_block) { + throw std::runtime_error("Op does not belong to a block"); + } + OpBlock::Iterator it(op->parent_block, op->parent_block->ops.begin()); + // Find the iterator for the specific op + for (; it.valid(); it.next()) { + if (*it == op) { + GetContext()->BeginCursor(it); + return; + } + } +} + +void EndCursor() { + GetContext()->EndCursor(); +} + void EndExecutionContext() { if (!current_context) { throw std::runtime_error("No execution context to end"); diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp index bf3782d8..f5da28b8 100644 --- a/TensorFrost/src/Compiler/OperationArguments.cpp +++ b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -6,6 +6,14 @@ void Arguments::AddInput(ArgType type, Op *from, int index) { 
from->args->SetAsOutput(inputs[index].get()); } +void Arguments::RemoveInput(int index) { + if (index < 0 || index >= inputs.size()) return; + Argument *arg = inputs[index].get(); + if(!arg || !arg->from) throw std::runtime_error("Invalid argument"); + arg->from->args->RemoveOutput(arg); + inputs[index].reset(); +} + bool Arguments::CheckValidity(bool throw_error) const { for (const auto& input : inputs) { if (!input || !input->from) { @@ -18,6 +26,31 @@ bool Arguments::CheckValidity(bool throw_error) const { return true; } +std::vector Arguments::Args() const { + std::vector result; + for (const auto& arg : inputs) { + if (arg) { + result.push_back(arg.get()); + } + } + return result; +} + +std::vector Arguments::Inputs() const { + std::vector result; + for (const auto& arg : inputs) { + if (arg && arg->from) { + result.push_back(arg->from); + } + } + return result; +} + +bool ShapeArgs::CompareShape(const ShapeArgs &other, bool throw_error) const { + //TODO: Implement shape comparison logic + return true; // Placeholder +} + ArgumentManager::ArgumentManager(Op *parent): parent_op(parent) { for (int i = 0; i < (int)ArgType::Shape; ++i) { type_args[i] = std::make_unique(); @@ -35,13 +68,21 @@ void ArgumentManager::SetAsOutput(Argument *arg) { type_args[(int)arg->type]->used_at.insert({arg->index, arg}); } +void ArgumentManager::RemoveOutput(Argument *arg) { + type_args[(int)arg->type]->used_at.erase({arg->index, arg}); +} + void ArgumentManager::SetArguments(ArgType type, std::vector args) { for (size_t i = 0; i < args.size(); ++i) { AddArgument(args[i], type, (int)i); } } -Arguments * ArgumentManager::GetArguments(ArgType type) const { +Arguments* ArgumentManager::Get(ArgType type) const { return type_args[(int)type].get(); } + +Arguments * ArgumentManager::operator[](ArgType type) const { + return Get(type); +} } diff --git a/TensorFrost/src/Compiler/OperationBlocks.cpp b/TensorFrost/src/Compiler/OperationBlocks.cpp index d53072d0..a9617dde 100644 --- 
a/TensorFrost/src/Compiler/OperationBlocks.cpp +++ b/TensorFrost/src/Compiler/OperationBlocks.cpp @@ -1,85 +1,53 @@ #include "Compiler/Operation.h" namespace TensorFrost { -Op* OpBlock::append(std::unique_ptr op) { - ops.emplace_back(std::move(op)); - return ops.back().get(); -} - OpBlock::OpBlock(Op *parent): parent_op(parent) {} -OpBlockIterator::OpBlockIterator(OpBlock* root) : current_op(nullptr) { - if (root && !root->ops.empty()) { - stack.push_back({root, root->ops.begin(), root->ops.end()}); - current_op = stack.back().it->get(); - } +OpBlock::Iterator::Iterator(OpBlock *parent, List::iterator it) + : parent_(parent), list_(&parent->ops), cur_(it) {} + +Op* OpBlock::Iterator::operator*() const { return cur_->get(); } +Op* OpBlock::Iterator::operator->() const { return cur_->get(); } + +Op * OpBlock::Iterator::get_next() const { return cur_->get(); } +Op * OpBlock::Iterator::get_prev() const { + if (cur_ == list_->begin()) return nullptr; + return std::prev(cur_)->get(); } -Op* OpBlockIterator::current() const { - return current_op; +OpBlock::Iterator & OpBlock::Iterator::next() { if (cur_ != list_->end()) ++cur_; return *this; } +OpBlock::Iterator & OpBlock::Iterator::prev() { if (cur_ != list_->begin()) --cur_; return *this; } + +OpBlock::Iterator& OpBlock::Iterator::insert_after(std::unique_ptr op) { + if (op->parent_block) throw std::runtime_error("Op already belongs to a block"); + + auto pos = (cur_ == list_->end()) ? 
list_->end() : std::next(cur_); + cur_ = list_->insert(pos, std::move(op)); // <- after-cursor + cur_->get()->parent_block = parent_; + return *this; } -Op* OpBlockIterator::next() { - if (stack.empty()) return nullptr; - // If current op has sub-blocks, go down - if (!current_op->blocks.empty() && current_op->blocks[0] && !current_op->blocks[0]->ops.empty()) { - OpBlock* sub = current_op->blocks[0].get(); - stack.push_back({sub, sub->ops.begin(), sub->ops.end()}); - current_op = stack.back().it->get(); - return current_op; - } - // Otherwise, go to next op in current block or up - while (!stack.empty()) { - auto& frame = stack.back(); - ++frame.it; - if (frame.it != frame.end) { - current_op = frame.it->get(); - return current_op; - } else { - stack.pop_back(); - } - } - current_op = nullptr; - return nullptr; +OpBlock::Iterator& OpBlock::Iterator::insert_before(std::unique_ptr op) { + if (op->parent_block) throw std::runtime_error("Op already belongs to a block"); + + cur_ = list_->insert(cur_, std::move(op)); // <- before-cursor + cur_->get()->parent_block = parent_; + return *this; } -Op* OpBlockIterator::prev() { - if (stack.empty()) return nullptr; - auto& frame = stack.back(); - if (frame.it == frame.block->ops.begin()) { - stack.pop_back(); - if (!stack.empty()) { - current_op = stack.back().it->get(); - return current_op; +bool OpBlock::Iterator::valid() const { return cur_ != list_->end(); } +bool OpBlock::Iterator::operator==(const Iterator &o) const { return cur_ == o.cur_; } +bool OpBlock::Iterator::operator!=(const Iterator &o) const { return cur_ != o.cur_; } + +OpBlock::Iterator OpBlock::begin() { return Iterator(this, ops.begin()); } +OpBlock::Iterator OpBlock::end() { return Iterator(this, ops.end()); } + +void ApplyOpTransform(OpBlock &block, const std::function &transform) { + for (auto& op : block.ops) { + for (auto& sub_block : op->blocks) { + ApplyOpTransform(*sub_block, transform); } - current_op = nullptr; - return nullptr; + 
transform(*op); } - --frame.it; - // Go to the deepest last op in sub-blocks if any - Op* op = frame.it->get(); - while (!op->blocks.empty() && op->blocks[0] && !op->blocks[0]->ops.empty()) { - OpBlock* sub = op->blocks[0].get(); - stack.push_back({sub, --sub->ops.end(), sub->ops.end()}); - op = stack.back().it->get(); - } - current_op = op; - return current_op; -} - -bool OpBlockIterator::down() { - if (!current_op || current_op->blocks.empty() || !current_op->blocks[0] || current_op->blocks[0]->ops.empty()) - return false; - OpBlock* sub = current_op->blocks[0].get(); - stack.push_back({sub, sub->ops.begin(), sub->ops.end()}); - current_op = stack.back().it->get(); - return true; } - -bool OpBlockIterator::up() { - if (stack.size() <= 1) return false; - stack.pop_back(); - current_op = stack.back().it->get(); - return true; } -} \ No newline at end of file diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index b04f4ae6..88d08e18 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -3,9 +3,10 @@ using namespace std; namespace TensorFrost { -OpSpec::OpSpec(std::string op_name, OverloadsMap overloads_list) { +OpSpec::OpSpec(std::string op_name, OverloadsMap overloads_list, int block_count) { name = std::move(op_name); overloads = std::move(overloads_list); + blocks = block_count; } TFDataFormat OpSpec::GetOutputType(const std::vector &args) const { @@ -63,7 +64,7 @@ vector default_operations = { OpSpec("mul", ovr("f(f,f); u(u,u); i(i,i)")), OpSpec("div", ovr("f(f,f); u(u,u); i(i,i)")), - OpSpec("parallel", ovr("tuple()")), + OpSpec("vmap", ovr("tuple()"), 1), }; std::unordered_map> CreateOperationRegistry() { diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index cfdc4b05..e9ebe2a2 100644 --- a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -17,8 +17,23 @@ Op& 
make_op(std::string op, std::vector mem, std::vector ids, std::vec op_instance->args->SetArguments(ArgType::Memory, mem); op_instance->args->SetArguments(ArgType::Index, ids); op_instance->args->SetArguments(ArgType::Input, args); + + if(shape.empty()) { + shape = op_instance->args->Get(ArgType::Shape)->Inputs(); + } + op_instance->args->SetArguments(ArgType::Shape, shape); - return GetContext()->AddOp(std::unique_ptr(op_instance)); + + // Create blocks + for (int i = 0; i < spec->blocks; ++i) { + op_instance->NewBlock(); + } + + return GetContext()->Add(std::unique_ptr(op_instance)); +} + +Op & func_op(const std::string &name, std::vector args, std::vector shape) { + return make_op(name, {}, {}, std::move(args), std::move(shape)); } Op& constant(int value) { @@ -48,4 +63,12 @@ Op& constant(bool value) { const_op.type = TFTypeBool32; return const_op; } -} \ No newline at end of file + +Op & vmap(std::vector shape, std::function body) { + Op& par_op = func_op("vmap", {}, shape); + GetContext()->BeginCursor(par_op.blocks.front()->begin()); + body(&par_op); + GetContext()->EndCursor(); + return par_op; +} +} diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp index fc4a97a8..9b3fa65e 100644 --- a/TensorFrost/src/Compiler/Printer.cpp +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -5,38 +5,93 @@ using namespace std; namespace TensorFrost { -void PrintOp(const Op &op, std::ostringstream &os) { - os << "Op: " << op.opcode << "\n"; - os << "Type: " << ToString(op.type) << "\n"; - os << "Arguments:\n"; - for (int i = 0; i < (int)ArgType::Count; ++i) { - const auto args = op.args->GetArguments(static_cast(i)); - if (args) { - os << " " << ToString(static_cast(i)) << ":\n"; - for (const auto& arg : args->inputs) { - if (arg) { - os << " From: " << (arg->from ? 
arg->from->opcode : "nullptr") - << ", Index: " << arg->index << "\n"; - } + +std::string VariableName(const Op* op) { + if (op->opcode == "const") { + return ToString(op->attributes.at("value")); + } + return op->varname; +} + +bool PrintArguments(const auto_vector>& vec, std::ostringstream &os, string begin, string end) { + if (vec.empty()) return false; + os << begin; + bool first = true; + for (const auto& v : vec) { + if (!first) os << ", "; + first = false; + os << VariableName(v->from); + } + os << end; + return true; +} + +void PrintOp(const Op* op, std::ostringstream &os) { + os << ToString(op->type) << " " << op->varname; + PrintArguments(op->args->Get(ArgType::Shape)->inputs, os, "[", "]"); + + if (op->opcode == "const") { + os << " = " << op->attributes.at("value"); + return; + } else { + os << " = " << op->opcode << "("; + // Print inputs + bool printed = PrintArguments(op->args->Get(ArgType::Input)->inputs, os, "", ""); + printed |= PrintArguments(op->args->Get(ArgType::Index)->inputs, os, ", index={", "}"); + printed |= PrintArguments(op->args->Get(ArgType::Memory)->inputs, os, ", memory={", "}"); + if (!op->attributes.empty()) { + if (printed) os << ", "; + os <<"{"; + bool first = true; + for (const auto& [key, value] : op->attributes) { + if (!first) os << ", "; + first = false; + os << key << ": "; + std::visit([&os](const auto& v) { os << v; }, value); } + os << "}"; } + os <<")"; } - os << "Attributes:\n"; - for (const auto& [key, value] : op.attributes) { - os << " " << key << ": "; - std::visit([&os](const auto& v) { os << v; }, value); - os << "\n"; +} + +std::string AddIndent(const std::string& str, int indent) { + // Add indentation to each line of the string + std::string indented; + std::istringstream iss(str); + std::string line; + while (std::getline(iss, line)) { + indented += std::string(indent, ' ') + line + "\n"; } + return indented; } -std::string PrintTree(const OpBlock &base_block) { - OpBlockIterator 
it(const_cast(&base_block)); - std::ostringstream oss; - while (Op* op = it.next()) { - PrintOp(*op, oss); +std::string PrintBlock(OpBlock &block) { + auto oss = std::ostringstream(); + for (auto it = block.begin(); it.valid(); it.next()) { + PrintOp(*it, oss); + if(it->blocks.size() > 0) { + bool first = true; + oss << " {\n"; + for (auto& sub_block : it->blocks) { + if (!first) oss << "{\n"; + oss< Date: Mon, 9 Jun 2025 02:22:14 +0200 Subject: [PATCH 08/44] Working on IR --- TensorFrost/PybindModule.cpp | 11 ++++- TensorFrost/include/Compiler/Common.h | 3 -- TensorFrost/include/Compiler/Operation.h | 8 ++++ .../include/Compiler/OperationArguments.h | 12 +---- .../include/Compiler/OperationRegistry.h | 29 +++++++++++- TensorFrost/include/Compiler/Overloads.h | 9 ++-- TensorFrost/src/Compiler/Operation.cpp | 29 ++++++++++++ .../src/Compiler/OperationArguments.cpp | 15 ++---- .../src/Compiler/OperationRegistry.cpp | 44 +++++++++++++---- TensorFrost/src/Compiler/Overloads.cpp | 47 ++++++++++++++----- TensorFrost/src/Compiler/Printer.cpp | 5 +- 11 files changed, 157 insertions(+), 55 deletions(-) diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index b5133c99..6b791f36 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -137,8 +137,15 @@ PYBIND11_MODULE(TensorFrost, m) { Op& a = constant(5); Op& b = constant(10); Op& c = a + b * 3; - vmap({&a, &b, &c}, [&](Op* op) { - Op& d = c + b; + Op& mem = memory({&a, &b, &c}, TFTypeFloat32); + mem.AddAttribute("program_input", 0); //is the first argument to the program + vmap({&a, &b, &c}, [&](Op& ids0) { + Op& imem = toint(mem); + Op& d = c + b + ids0[0] * imem; + vmap({&c, &c}, [&](Op& ids1) { + Op& m = d * c * imem[{&ids1[1], &ids1[0], &ids0[0]}]; + m.AddAttribute("program_output", 0); //is the first output of the program + }); }); AssignVariableNames(*GetBaseBlock()); std::string tree = PrintBlock(*GetBaseBlock()); diff --git a/TensorFrost/include/Compiler/Common.h 
b/TensorFrost/include/Compiler/Common.h index 69e37d64..0b7969f0 100644 --- a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -69,7 +69,6 @@ enum class ArgType { Input, Index, Memory, - Shape, //must be last Count, }; @@ -77,7 +76,6 @@ inline std::string ToString(ArgType type) { switch (type) { case ArgType::Input: return "Input"; case ArgType::Index: return "Index"; - case ArgType::Shape: return "Shape"; case ArgType::Memory: return "Memory"; default: return "Unknown"; } @@ -102,7 +100,6 @@ struct Arguments; struct OpBlock; class OpBlockIterator; struct ArgumentManager; -struct ShapeArgs; struct Argument; using Attribute = std::variant; diff --git a/TensorFrost/include/Compiler/Operation.h b/TensorFrost/include/Compiler/Operation.h index 5e2eaf36..42824d61 100644 --- a/TensorFrost/include/Compiler/Operation.h +++ b/TensorFrost/include/Compiler/Operation.h @@ -22,6 +22,14 @@ struct Op { Op(std::string op_name); OpBlock* NewBlock(); + OpBlock& GetBlock(int index = 0); + + Op& operator[](int index); + Op& operator[](std::vector indices); + + void AddAttribute(const std::string& name, const Attribute& value); + void ChangeAttribute(const std::string& name, const Attribute& value); + void GetAttribute(const std::string& name, Attribute& value) const; }; diff --git a/TensorFrost/include/Compiler/OperationArguments.h b/TensorFrost/include/Compiler/OperationArguments.h index 01e304d4..66083aaf 100644 --- a/TensorFrost/include/Compiler/OperationArguments.h +++ b/TensorFrost/include/Compiler/OperationArguments.h @@ -23,23 +23,15 @@ struct Arguments { std::vector Inputs() const; }; -struct ShapeArgs : Arguments { - std::vector TryGetShape(int default_value = 256) const; - float GetSizeEstimate(); - void ExpandDimensionsTo(int new_dim); - - bool CompareShape(const ShapeArgs& other, bool throw_error = false) const; -}; - struct ArgumentManager { Op* parent_op = nullptr; std::array, (int)ArgType::Count> type_args; ArgumentManager(Op* 
parent); - void AddArgument(Op* from, ArgType type, int index = 0); + void AddArgument(Op &from, ArgType type, int index = 0); void SetAsOutput(Argument *arg); void RemoveOutput(Argument *arg); - void SetArguments(ArgType type, std::vector args); + void SetArguments(ArgType type, std::vector args); Arguments* Get(ArgType type) const; Arguments* operator[](ArgType type) const; diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h index 1cf138c5..79dab346 100644 --- a/TensorFrost/include/Compiler/OperationRegistry.h +++ b/TensorFrost/include/Compiler/OperationRegistry.h @@ -5,12 +5,39 @@ namespace TensorFrost { using OverloadsMap = std::unordered_map, TFDataFormat, VecHash>; +enum class OpClass { + Operator, + UnaryOperator, + Function, + Copy, + Keyword, + Parallel, + Variable, + TypeCast, + TypeReinterpret, + Constant, + TernaryOperator, + Memory, + None, +}; + +enum class OpProp { + HasShape, + Load, + Store, + MemoryOp, + Set, +}; + struct OpSpec { std::string name; OverloadsMap overloads; + OpClass op_class = OpClass::None; + std::set properties; int blocks = 0; - OpSpec(std::string op_name, OverloadsMap overloads_list, int block_count = 0); + OpSpec(std::string op_name, OverloadsMap overloads_list, int block_count = 0, + OpClass op_class_type = OpClass::None, std::set props = {}); TFDataFormat GetOutputType(const std::vector& args) const; }; diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h index ffdf312e..22f23ddf 100644 --- a/TensorFrost/include/Compiler/Overloads.h +++ b/TensorFrost/include/Compiler/Overloads.h @@ -2,8 +2,8 @@ #include "Operation.h" namespace TensorFrost { -Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args, std::vector shape); -Op& func_op(const std::string& name, std::vector args = {}, std::vector shape = {}); +Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args); +Op& 
func_op(const std::string& name, std::vector args = {}); Op& constant(int value); Op& constant(uint value); @@ -137,5 +137,8 @@ TERNARY_FUNCTION(smoothstep, "smoothstep") TERNARY_FUNCTION(select, "ternary") TERNARY_FUNCTION(fma, "fma") -Op& vmap(std::vector shape, std::function body); +Op& unpack_tuple(const Op& x, int index = 0); +Op& vmap(std::vector shape, std::function body); +Op& memory(std::vector shape, TFDataFormat type); +Op& load_at_index(const Op& mem, std::vector indices); } diff --git a/TensorFrost/src/Compiler/Operation.cpp b/TensorFrost/src/Compiler/Operation.cpp index bcd0a31f..efe95b59 100644 --- a/TensorFrost/src/Compiler/Operation.cpp +++ b/TensorFrost/src/Compiler/Operation.cpp @@ -10,4 +10,33 @@ OpBlock* Op::NewBlock() { blocks.emplace_back(std::make_unique(this)); return blocks.back().get(); } + +OpBlock& Op::GetBlock(int index) { + if (index < 0 || index >= blocks.size()) { + throw std::out_of_range("Block index out of range"); + } + return *blocks[index]; +} + +void Op::AddAttribute(const std::string &name, const Attribute &value) { + if (attributes.find(name) != attributes.end()) { + throw std::runtime_error("Attribute '" + name + "' already exists in operation '" + opcode + "'"); + } + attributes[name] = value; +} + +void Op::ChangeAttribute(const std::string &name, const Attribute &value) { + if (attributes.find(name) == attributes.end()) { + throw std::runtime_error("Attribute '" + name + "' not found in operation '" + opcode + "'"); + } + attributes[name] = value; +} + +void Op::GetAttribute(const std::string &name, Attribute &value) const { + auto it = attributes.find(name); + if (it == attributes.end()) { + throw std::runtime_error("Attribute '" + name + "' not found in operation '" + opcode + "'"); + } + value = it->second; +} } diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp index f5da28b8..049334ba 100644 --- a/TensorFrost/src/Compiler/OperationArguments.cpp +++ 
b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -46,22 +46,15 @@ std::vector Arguments::Inputs() const { return result; } -bool ShapeArgs::CompareShape(const ShapeArgs &other, bool throw_error) const { - //TODO: Implement shape comparison logic - return true; // Placeholder -} - ArgumentManager::ArgumentManager(Op *parent): parent_op(parent) { - for (int i = 0; i < (int)ArgType::Shape; ++i) { + for (int i = 0; i < (int)ArgType::Count; ++i) { type_args[i] = std::make_unique(); type_args[i]->parent_op = parent; } - type_args[(int)ArgType::Shape] = std::make_unique(); - type_args[(int)ArgType::Shape]->parent_op = parent; } -void ArgumentManager::AddArgument(Op *from, ArgType type, int index) { - type_args[(int)type]->AddInput(type, from, index); +void ArgumentManager::AddArgument(Op &from, ArgType type, int index) { + type_args[(int)type]->AddInput(type, &from, index); } void ArgumentManager::SetAsOutput(Argument *arg) { @@ -74,7 +67,7 @@ void ArgumentManager::RemoveOutput(Argument *arg) { void ArgumentManager::SetArguments(ArgType type, std::vector args) { for (size_t i = 0; i < args.size(); ++i) { - AddArgument(args[i], type, (int)i); + AddArgument(*args[i], type, (int)i); } } diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index 88d08e18..cea0f626 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -3,16 +3,31 @@ using namespace std; namespace TensorFrost { -OpSpec::OpSpec(std::string op_name, OverloadsMap overloads_list, int block_count) { +OpSpec::OpSpec(std::string op_name, OverloadsMap overloads_list, int block_count, OpClass op_class_type, std::set props) { name = std::move(op_name); overloads = std::move(overloads_list); blocks = block_count; + op_class = op_class_type; + properties = std::move(props); } TFDataFormat OpSpec::GetOutputType(const std::vector &args) const { + if (properties.contains(OpProp::HasShape) || args.empty()) 
{ + return overloads.find({})->second; + } auto it = overloads.find(args); if (it == overloads.end()) { - throw std::runtime_error("No overload found for operation: " + name + " with args: " + to_string(args.size())); + std::string error_msg = "No overload found for operation: " + name + " with args: ("; + for (const auto& arg : args) { + error_msg += ToString(arg) + ", "; + } + if (!args.empty()) { + error_msg.pop_back(); // Remove last comma + error_msg.pop_back(); // Remove last space + } + error_msg += ")"; + + throw std::runtime_error(error_msg); } return it->second; } @@ -57,14 +72,27 @@ OverloadsMap ovr(const std::string& input) { } vector default_operations = { - OpSpec("const", ovr("f(); u(); i(); b(); tuple()")), + OpSpec("memory", ovr("f(); u(); i(); b(); tuple()"), 0, OpClass::Memory, {OpProp::HasShape}), + OpSpec("const", ovr("f(); u(); i(); b(); tuple()"), 0, OpClass::Constant), + + OpSpec("add", ovr("f(f,f); u(u,u); i(i,i)"), 0, OpClass::Operator), + OpSpec("sub", ovr("f(f,f); u(u,u); i(i,i)"), 0, OpClass::Operator), + OpSpec("mul", ovr("f(f,f); u(u,u); i(i,i)"), 0, OpClass::Operator), + OpSpec("div", ovr("f(f,f); u(u,u); i(i,i)"), 0, OpClass::Operator), + + OpSpec("vmap", ovr("tuple()"), 1, OpClass::Parallel, {OpProp::HasShape}), + + OpSpec("unpack_tuple_int", ovr("i(tuple)"), 0, OpClass::Function), + + OpSpec("copy", ovr("f(f); u(u); i(i); b(b)"), 0, OpClass::Copy), - OpSpec("add", ovr("f(f,f); u(u,u); i(i,i)")), - OpSpec("sub", ovr("f(f,f); u(u,u); i(i,i)")), - OpSpec("mul", ovr("f(f,f); u(u,u); i(i,i)")), - OpSpec("div", ovr("f(f,f); u(u,u); i(i,i)")), + OpSpec("tofloat", ovr("f(i); f(u); f(b)"), 0, OpClass::TypeCast), + OpSpec("toint", ovr("i(f); i(u); i(b)"), 0, OpClass::TypeCast), + OpSpec("touint", ovr("u(f); u(i); u(b)"), 0, OpClass::TypeCast), + OpSpec("tobool", ovr("b(f); b(i); b(u)"), 0, OpClass::TypeCast), - OpSpec("vmap", ovr("tuple()"), 1), + OpSpec("load", ovr("f(); u(); i(); b()"), 0, OpClass::Function, {OpProp::Load, 
OpProp::MemoryOp}), + OpSpec("store", ovr("f(); u(); i(); b()"), 0, OpClass::Function, {OpProp::Store, OpProp::MemoryOp}), }; std::unordered_map> CreateOperationRegistry() { diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index e9ebe2a2..e3272c20 100644 --- a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -5,7 +5,7 @@ using namespace std; namespace TensorFrost { // General function to create an Op instance in the current execution context -Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args, std::vector shape) { +Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args) { OpSpec* spec = GetOpSpec(op); vector arg_types; for (const auto& arg : args) { @@ -18,12 +18,6 @@ Op& make_op(std::string op, std::vector mem, std::vector ids, std::vec op_instance->args->SetArguments(ArgType::Index, ids); op_instance->args->SetArguments(ArgType::Input, args); - if(shape.empty()) { - shape = op_instance->args->Get(ArgType::Shape)->Inputs(); - } - - op_instance->args->SetArguments(ArgType::Shape, shape); - // Create blocks for (int i = 0; i < spec->blocks; ++i) { op_instance->NewBlock(); @@ -32,8 +26,8 @@ Op& make_op(std::string op, std::vector mem, std::vector ids, std::vec return GetContext()->Add(std::unique_ptr(op_instance)); } -Op & func_op(const std::string &name, std::vector args, std::vector shape) { - return make_op(name, {}, {}, std::move(args), std::move(shape)); +Op & func_op(const std::string &name, std::vector args) { + return make_op(name, {}, {}, std::move(args)); } Op& constant(int value) { @@ -64,11 +58,38 @@ Op& constant(bool value) { return const_op; } -Op & vmap(std::vector shape, std::function body) { - Op& par_op = func_op("vmap", {}, shape); - GetContext()->BeginCursor(par_op.blocks.front()->begin()); - body(&par_op); +Op& unpack_tuple(const Op &x, int index) { + Op& elem = func_op("unpack_tuple_int", {as_op(x)}); + 
elem.attributes["index"] = index; // Default index + return elem; +} + +Op& vmap(std::vector shape, std::function body) { + Op& par_op = func_op("vmap", std::move(shape)); + GetContext()->BeginCursor(par_op.GetBlock().begin()); + body(par_op); GetContext()->EndCursor(); return par_op; } + +Op& memory(std::vector shape, TFDataFormat type) { + Op& mem_op = func_op("memory", std::move(shape)); + mem_op.type = type; + return mem_op; +} + +Op& load_at_index(const Op& mem, std::vector indices) { + Op& load_op = make_op("load", {as_op(mem)}, indices, {}); + load_op.type = mem.type; // Assume the loaded type is the same as the memory type + return load_op; +} + +Op& Op::operator[](int index) { + return unpack_tuple(*this, index); +} + +Op& Op::operator[](std::vector indices) { + return load_at_index(*this, std::move(indices)); +} + } diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp index 9b3fa65e..89606a0d 100644 --- a/TensorFrost/src/Compiler/Printer.cpp +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -28,11 +28,8 @@ bool PrintArguments(const auto_vector>& vec, std::ostr void PrintOp(const Op* op, std::ostringstream &os) { os << ToString(op->type) << " " << op->varname; - PrintArguments(op->args->Get(ArgType::Shape)->inputs, os, "[", "]"); - if (op->opcode == "const") { os << " = " << op->attributes.at("value"); - return; } else { os << " = " << op->opcode << "("; // Print inputs @@ -77,7 +74,7 @@ std::string PrintBlock(OpBlock &block) { if (!first) oss << "{\n"; oss< Date: Tue, 10 Jun 2025 05:22:50 +0200 Subject: [PATCH 09/44] Better print + TFProgram --- TensorFrost/PybindModule.cpp | 34 ++++----- TensorFrost/include/Compiler/Common.h | 18 ++++- .../include/Compiler/ExecutionContext.h | 2 +- .../include/Compiler/OperationRegistry.h | 46 +++++++++++-- TensorFrost/include/Compiler/Overloads.h | 2 +- TensorFrost/include/Compiler/Printer.h | 5 ++ TensorFrost/include/Compiler/TFProgram.h | 16 +++++ TensorFrost/include/TensorFrost.h | 3 
+- TensorFrost/src/Compiler/ExecutionContext.cpp | 8 ++- .../src/Compiler/OperationRegistry.cpp | 64 ++++++++++------- TensorFrost/src/Compiler/Overloads.cpp | 11 ++- TensorFrost/src/Compiler/Printer.cpp | 69 ++++++++++--------- TensorFrost/src/Compiler/TFProgram.cpp | 24 +++++++ 13 files changed, 208 insertions(+), 94 deletions(-) create mode 100644 TensorFrost/include/Compiler/TFProgram.h create mode 100644 TensorFrost/src/Compiler/TFProgram.cpp diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 6b791f36..87d4c228 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -133,24 +133,26 @@ PYBIND11_MODULE(TensorFrost, m) { #endif // TEST CODE - StartExecutionContext(); - Op& a = constant(5); - Op& b = constant(10); - Op& c = a + b * 3; - Op& mem = memory({&a, &b, &c}, TFTypeFloat32); - mem.AddAttribute("program_input", 0); //is the first argument to the program - vmap({&a, &b, &c}, [&](Op& ids0) { - Op& imem = toint(mem); - Op& d = c + b + ids0[0] * imem; - vmap({&c, &c}, [&](Op& ids1) { - Op& m = d * c * imem[{&ids1[1], &ids1[0], &ids0[0]}]; - m.AddAttribute("program_output", 0); //is the first output of the program + TFProgram program([]() -> auto { + std::vector inputs; + std::vector outputs; + Op& a = constant(5); + Op& b = constant(10); + Op& c = a + b * 3; + Op& mem = memory({&a, &b, &c}, TFTypeFloat32); + inputs.push_back(&mem); + vmap({&a, &b, &c}, [&](Op& ids0) { + Op& imem = toint(mem); + Op& d = c + b + ids0[0] * imem; + vmap({&c, &c}, [&](Op& ids1) { + Op& m = d * c * imem[{&ids1[1], &ids1[0], &ids0[0]}]; + outputs.push_back(&m); + }); }); + return std::make_pair(inputs, outputs); }); - AssignVariableNames(*GetBaseBlock()); - std::string tree = PrintBlock(*GetBaseBlock()); - py::print("Created operation tree:"); - py::print(tree); + + py::print(program.DebugPrint()); } } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/include/Compiler/Common.h 
b/TensorFrost/include/Compiler/Common.h index 0b7969f0..62c64299 100644 --- a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -11,6 +11,8 @@ #include #include #include +#include +#include namespace TensorFrost { extern "C" { @@ -68,7 +70,6 @@ struct VecHash { enum class ArgType { Input, Index, - Memory, Count, }; @@ -76,7 +77,6 @@ inline std::string ToString(ArgType type) { switch (type) { case ArgType::Input: return "Input"; case ArgType::Index: return "Index"; - case ArgType::Memory: return "Memory"; default: return "Unknown"; } } @@ -102,8 +102,9 @@ class OpBlockIterator; struct ArgumentManager; struct Argument; -using Attribute = std::variant; +using Attribute = std::variant; using AttributeMap = std::unordered_map; +using AttributeSpan = std::span; //ostringstream conversion for Attribute inline std::ostream& operator<<(std::ostream& os, const Attribute& attr) { @@ -117,7 +118,18 @@ inline std::string ToString(const Attribute& attr) { return oss.str(); } +template +auto TransformVector(const Container& input, Func func) { + using T2 = decltype(func(*std::begin(input))); + std::vector output; + output.reserve(input.size()); + for (const auto& item : input) { + output.push_back(func(item)); + } + return output; +} } + namespace std { template<> struct hash { diff --git a/TensorFrost/include/Compiler/ExecutionContext.h b/TensorFrost/include/Compiler/ExecutionContext.h index 55805a9d..15ec02e8 100644 --- a/TensorFrost/include/Compiler/ExecutionContext.h +++ b/TensorFrost/include/Compiler/ExecutionContext.h @@ -18,7 +18,7 @@ struct ExecutionContext { Op &AddBeforeCursor(std::unique_ptr op); }; -void StartExecutionContext(); +void StartExecutionContext(ExecutionContext* ctx); ExecutionContext* GetContext(); OpBlock* GetBaseBlock(); OpBlock* GetCurrentBlock(); diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h index 79dab346..55ebd860 100644 --- 
a/TensorFrost/include/Compiler/OperationRegistry.h +++ b/TensorFrost/include/Compiler/OperationRegistry.h @@ -22,22 +22,58 @@ enum class OpClass { }; enum class OpProp { - HasShape, + ShapeArgs, Load, Store, MemoryOp, Set, }; +using FoldFn = std::function; + +[[noreturn]] inline void bad_arity(std::size_t expect, std::size_t got) +{ + throw std::invalid_argument("constant-fold expects " + std::to_string(expect) + + " operands, got " + std::to_string(got)); +} + +template +FoldFn make_fold1(F f) { + return [f = std::move(f)](AttributeSpan a) -> Attribute { + if (a.size() != 1) bad_arity(1, a.size()); + return std::visit([&](auto &&x) -> Attribute { + return f(std::forward(x)); + }, a[0]); + }; +} + +template +FoldFn make_fold2(F f) { + return [f = std::move(f)](AttributeSpan a) -> Attribute { + if (a.size() != 2) bad_arity(2, a.size()); + return std::visit([&](auto &&x, auto &&y) -> Attribute { + return f(std::forward(x), std::forward(y)); + }, a[0], a[1]); + }; +} + +template +FoldFn make_fold3(F f) { + return [f = std::move(f)](AttributeSpan a) -> Attribute { + if (a.size() != 3) bad_arity(3, a.size()); + return std::visit([&](auto &&x, auto &&y, auto &&z) -> Attribute { + return f(std::forward(x), std::forward(y), std::forward(z)); + }, a[0], a[1], a[2]); + }; +} + struct OpSpec { std::string name; OverloadsMap overloads; OpClass op_class = OpClass::None; - std::set properties; + std::set props; int blocks = 0; - - OpSpec(std::string op_name, OverloadsMap overloads_list, int block_count = 0, - OpClass op_class_type = OpClass::None, std::set props = {}); + FoldFn constant_fold = nullptr; TFDataFormat GetOutputType(const std::vector& args) const; }; diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h index 22f23ddf..38c2624a 100644 --- a/TensorFrost/include/Compiler/Overloads.h +++ b/TensorFrost/include/Compiler/Overloads.h @@ -2,7 +2,7 @@ #include "Operation.h" namespace TensorFrost { -Op& make_op(std::string op, 
std::vector mem, std::vector ids, std::vector args); +Op& make_op(std::string op, std::vector ids, std::vector args); Op& func_op(const std::string& name, std::vector args = {}); Op& constant(int value); diff --git a/TensorFrost/include/Compiler/Printer.h b/TensorFrost/include/Compiler/Printer.h index 79ca40df..8a1b7a8d 100644 --- a/TensorFrost/include/Compiler/Printer.h +++ b/TensorFrost/include/Compiler/Printer.h @@ -4,8 +4,13 @@ namespace TensorFrost { +std::string VariableName(const Op* op); void PrintOp(const Op& op, std::ostream& os); std::string PrintBlock(OpBlock& base_block); void AssignVariableNames(OpBlock &block); +std::string PrintAttribute(Attribute attr); +std::string AddIndent(const std::string& str, int indent); +std::string PrintArray(std::vector items, const std::string& begin = "", const std::string& end = "", + const std::string& separator = ", "); } diff --git a/TensorFrost/include/Compiler/TFProgram.h b/TensorFrost/include/Compiler/TFProgram.h new file mode 100644 index 00000000..750244dd --- /dev/null +++ b/TensorFrost/include/Compiler/TFProgram.h @@ -0,0 +1,16 @@ +#pragma once +#include "Operation.h" +#include "ExecutionContext.h" +#include "Printer.h" + +namespace TensorFrost { +class TFProgram { +public: + ExecutionContext context; + std::vector program_inputs; + std::vector program_outputs; + + TFProgram(std::function, std::vector>()> program_fn); + std::string DebugPrint() const; +}; +} \ No newline at end of file diff --git a/TensorFrost/include/TensorFrost.h b/TensorFrost/include/TensorFrost.h index 4fa8ac4d..8ce5653f 100644 --- a/TensorFrost/include/TensorFrost.h +++ b/TensorFrost/include/TensorFrost.h @@ -5,4 +5,5 @@ #include "Compiler/OperationBlocks.h" #include "Compiler/OperationArguments.h" #include "Compiler/Overloads.h" -#include "Compiler/Printer.h" \ No newline at end of file +#include "Compiler/Printer.h" +#include "Compiler/TFProgram.h" \ No newline at end of file diff --git a/TensorFrost/src/Compiler/ExecutionContext.cpp 
b/TensorFrost/src/Compiler/ExecutionContext.cpp index 825bb9b9..a0688a07 100644 --- a/TensorFrost/src/Compiler/ExecutionContext.cpp +++ b/TensorFrost/src/Compiler/ExecutionContext.cpp @@ -32,11 +32,14 @@ Op& ExecutionContext::AddBeforeCursor(std::unique_ptr op) { ExecutionContext* current_context = nullptr; -void StartExecutionContext() { +void StartExecutionContext(ExecutionContext* ctx) { if (current_context) { throw std::runtime_error("Execution context already started"); } - current_context = new ExecutionContext(); + if (!ctx) { + throw std::invalid_argument("Execution context cannot be null"); + } + current_context = ctx; } ExecutionContext* GetContext() { @@ -87,7 +90,6 @@ void EndExecutionContext() { if (!current_context) { throw std::runtime_error("No execution context to end"); } - delete current_context; current_context = nullptr; } } \ No newline at end of file diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index cea0f626..c4236af3 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -3,16 +3,8 @@ using namespace std; namespace TensorFrost { -OpSpec::OpSpec(std::string op_name, OverloadsMap overloads_list, int block_count, OpClass op_class_type, std::set props) { - name = std::move(op_name); - overloads = std::move(overloads_list); - blocks = block_count; - op_class = op_class_type; - properties = std::move(props); -} - TFDataFormat OpSpec::GetOutputType(const std::vector &args) const { - if (properties.contains(OpProp::HasShape) || args.empty()) { + if (props.contains(OpProp::ShapeArgs) || args.empty()) { return overloads.find({})->second; } auto it = overloads.find(args); @@ -48,7 +40,7 @@ static std::string trim(std::string_view s) { return std::string{s.substr(a, b - a)}; } -OverloadsMap ovr(const std::string& input) { +OverloadsMap GenerateOverloadMap(const std::string& input) { OverloadsMap out; std::stringstream ss(input); 
std::string stmt; @@ -71,28 +63,48 @@ OverloadsMap ovr(const std::string& input) { return out; } -vector default_operations = { - OpSpec("memory", ovr("f(); u(); i(); b(); tuple()"), 0, OpClass::Memory, {OpProp::HasShape}), - OpSpec("const", ovr("f(); u(); i(); b(); tuple()"), 0, OpClass::Constant), +#define DEF_OP(op_name, overload_str, operation_class, ...) \ + OpSpec{ .name = op_name, .overloads = GenerateOverloadMap(overload_str), .op_class = operation_class, __VA_ARGS__ } + +#define BIN_OP_FOLD(op) \ + make_fold2([](auto a, auto b) { return a op b; }) + +#define UN_OP_FOLD(op) \ + make_fold1([](auto a) { return op a; }) - OpSpec("add", ovr("f(f,f); u(u,u); i(i,i)"), 0, OpClass::Operator), - OpSpec("sub", ovr("f(f,f); u(u,u); i(i,i)"), 0, OpClass::Operator), - OpSpec("mul", ovr("f(f,f); u(u,u); i(i,i)"), 0, OpClass::Operator), - OpSpec("div", ovr("f(f,f); u(u,u); i(i,i)"), 0, OpClass::Operator), +#define UN_FUNC_FOLD(op) \ + make_fold1([](auto a) { return op(a); }) + +#define BIN_FUNC_FOLD(op) \ + make_fold2([](auto a, auto b) { return op(a, b); }) + +#define TERN_FUNC_FOLD(op) \ + make_fold3([](auto a, auto b, auto c) { return op(a, b, c); }) + +vector default_operations = { + DEF_OP("memory", "f(); u(); i(); b(); tuple()", OpClass::Memory, .props = {OpProp::ShapeArgs}), + DEF_OP("load", "f(f); u(u); i(i); b(b)", OpClass::Function, .props = {OpProp::Load, OpProp::MemoryOp}), + DEF_OP("store", "f(f); u(u); i(i); b(b)", OpClass::Function, .props = {OpProp::Store, OpProp::MemoryOp}), - OpSpec("vmap", ovr("tuple()"), 1, OpClass::Parallel, {OpProp::HasShape}), + DEF_OP("const", "f(); u(); i(); b(); tuple()", OpClass::Constant), + DEF_OP("copy", "f(f); u(u); i(i); b(b)", OpClass::Copy), + DEF_OP("add", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .constant_fold = BIN_OP_FOLD(+)), + DEF_OP("sub", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .constant_fold = BIN_OP_FOLD(-)), + DEF_OP("mul", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .constant_fold = BIN_OP_FOLD(*)), 
+ DEF_OP("div", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .constant_fold = BIN_OP_FOLD(/)), + DEF_OP("sin", "f(f); u(u); i(i)", OpClass::UnaryOperator, .constant_fold = UN_FUNC_FOLD(std::sinf)), + DEF_OP("cos", "f(f); u(u); i(i)", OpClass::UnaryOperator, .constant_fold = UN_FUNC_FOLD(std::cosf)), + DEF_OP("tan", "f(f); u(u); i(i)", OpClass::UnaryOperator, .constant_fold = UN_FUNC_FOLD(std::tanf)), - OpSpec("unpack_tuple_int", ovr("i(tuple)"), 0, OpClass::Function), - OpSpec("copy", ovr("f(f); u(u); i(i); b(b)"), 0, OpClass::Copy), + DEF_OP("tofloat", "f(i); f(u); f(b)", OpClass::TypeCast), + DEF_OP("toint", "i(f); i(u); i(b)", OpClass::TypeCast), + DEF_OP("touint", "u(f); u(i); u(b)", OpClass::TypeCast), + DEF_OP("tobool", "b(f); b(i); b(u)", OpClass::TypeCast), - OpSpec("tofloat", ovr("f(i); f(u); f(b)"), 0, OpClass::TypeCast), - OpSpec("toint", ovr("i(f); i(u); i(b)"), 0, OpClass::TypeCast), - OpSpec("touint", ovr("u(f); u(i); u(b)"), 0, OpClass::TypeCast), - OpSpec("tobool", ovr("b(f); b(i); b(u)"), 0, OpClass::TypeCast), + DEF_OP("unpack_tuple_int", "i(tuple)", OpClass::Function), - OpSpec("load", ovr("f(); u(); i(); b()"), 0, OpClass::Function, {OpProp::Load, OpProp::MemoryOp}), - OpSpec("store", ovr("f(); u(); i(); b()"), 0, OpClass::Function, {OpProp::Store, OpProp::MemoryOp}), + DEF_OP("vmap", "tuple()", OpClass::Parallel, .props = {OpProp::ShapeArgs}, .blocks = 1), }; std::unordered_map> CreateOperationRegistry() { diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index e3272c20..d83bdaab 100644 --- a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -5,7 +5,7 @@ using namespace std; namespace TensorFrost { // General function to create an Op instance in the current execution context -Op& make_op(std::string op, std::vector mem, std::vector ids, std::vector args) { +Op& make_op(std::string op, std::vector ids, std::vector args) { OpSpec* spec = GetOpSpec(op); vector 
arg_types; for (const auto& arg : args) { @@ -14,7 +14,6 @@ Op& make_op(std::string op, std::vector mem, std::vector ids, std::vec TFDataFormat output_type = spec->GetOutputType(arg_types); Op* op_instance = new Op(op); op_instance->type = output_type; - op_instance->args->SetArguments(ArgType::Memory, mem); op_instance->args->SetArguments(ArgType::Index, ids); op_instance->args->SetArguments(ArgType::Input, args); @@ -26,8 +25,8 @@ Op& make_op(std::string op, std::vector mem, std::vector ids, std::vec return GetContext()->Add(std::unique_ptr(op_instance)); } -Op & func_op(const std::string &name, std::vector args) { - return make_op(name, {}, {}, std::move(args)); +Op & func_op(const std::string &name, std::vector args) { + return make_op(name, {}, std::move(args)); } Op& constant(int value) { @@ -79,9 +78,7 @@ Op& memory(std::vector shape, TFDataFormat type) { } Op& load_at_index(const Op& mem, std::vector indices) { - Op& load_op = make_op("load", {as_op(mem)}, indices, {}); - load_op.type = mem.type; // Assume the loaded type is the same as the memory type - return load_op; + return make_op("load", indices, {as_op(mem)}); } Op& Op::operator[](int index) { diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp index 89606a0d..e3e28400 100644 --- a/TensorFrost/src/Compiler/Printer.cpp +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -13,17 +13,35 @@ std::string VariableName(const Op* op) { return op->varname; } -bool PrintArguments(const auto_vector>& vec, std::ostringstream &os, string begin, string end) { - if (vec.empty()) return false; - os << begin; +std::vector StringifyArguments(const auto_vector>& vec) { + return TransformVector(vec, [](const std::unique_ptr& arg) { + return VariableName(arg->from); + }); +} + +std::string PrintArray(std::vector items, const std::string &begin, const std::string &end, const std::string& separator) { + std::ostringstream oss; + if (items.empty()) return ""; + oss << begin; bool first = true; 
- for (const auto& v : vec) { - if (!first) os << ", "; + for (const auto& item : items) { + if (item.empty()) continue; // Skip empty items + if (!first) oss << separator; first = false; - os << VariableName(v->from); + oss << item; } - os << end; - return true; + oss << end; + return oss.str(); +} + +std::string PrintArguments(const auto_vector>& vec, string begin, string end) { + return PrintArray(StringifyArguments(vec), begin, end); +} + +std::string PrintAttribute(Attribute attr) { + std::ostringstream oss; + std::visit([&oss](const auto& v) { oss << v; }, attr); + return oss.str(); } void PrintOp(const Op* op, std::ostringstream &os) { @@ -31,24 +49,15 @@ void PrintOp(const Op* op, std::ostringstream &os) { if (op->opcode == "const") { os << " = " << op->attributes.at("value"); } else { - os << " = " << op->opcode << "("; - // Print inputs - bool printed = PrintArguments(op->args->Get(ArgType::Input)->inputs, os, "", ""); - printed |= PrintArguments(op->args->Get(ArgType::Index)->inputs, os, ", index={", "}"); - printed |= PrintArguments(op->args->Get(ArgType::Memory)->inputs, os, ", memory={", "}"); - if (!op->attributes.empty()) { - if (printed) os << ", "; - os <<"{"; - bool first = true; - for (const auto& [key, value] : op->attributes) { - if (!first) os << ", "; - first = false; - os << key << ": "; - std::visit([&os](const auto& v) { os << v; }, value); - } - os << "}"; + std::string inputs = PrintArguments(op->args->Get(ArgType::Input)->inputs, "", ""); + std::string index = PrintArguments(op->args->Get(ArgType::Index)->inputs, "index={", "}"); + std::vector attributes; + for (const auto& [key, value] : op->attributes) { + attributes.push_back(key + ": " + PrintAttribute(value)); } - os <<")"; + std::string attributes_str = PrintArray(attributes, "{", "}"); + + os << " = " << op->opcode << "(" << PrintArray({inputs, index, attributes_str}) << ")"; } } @@ -63,19 +72,17 @@ std::string AddIndent(const std::string& str, int indent) { return indented; } + 
std::string PrintBlock(OpBlock &block) { auto oss = std::ostringstream(); for (auto it = block.begin(); it.valid(); it.next()) { PrintOp(*it, oss); if(it->blocks.size() > 0) { - bool first = true; - oss << " {\n"; + std::vector blocks; for (auto& sub_block : it->blocks) { - if (!first) oss << "{\n"; - oss<, std::vector>()> program_fn) { + StartExecutionContext(&context); + + auto [ins, outs] = program_fn(); + program_inputs = std::move(ins); + program_outputs = std::move(outs); + if (program_outputs.empty()) { + throw std::runtime_error("Program must have at least one output"); + } + + AssignVariableNames(*GetBaseBlock()); + EndExecutionContext(); +} + +std::string TFProgram::DebugPrint() const { + std::string program_header = "TFProgram(inputs=" + PrintArray(TransformVector(program_inputs, VariableName), "[", "]") + ") {\n"; + std::string inner_code = PrintBlock(*context.base_block); + inner_code += "return " + PrintArray(TransformVector(program_outputs, VariableName), "[", "]") + ";\n"; + return program_header + AddIndent(inner_code, 2) + "}\n"; +} +} From 8a4f18dac39cccd06da18f30a7e9ae7a4665c639 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Tue, 10 Jun 2025 05:43:42 +0200 Subject: [PATCH 10/44] Simple compilation pass --- TensorFrost/PybindModule.cpp | 2 +- TensorFrost/include/Compiler/Common.h | 4 +--- .../include/Compiler/OperationArguments.h | 3 +++ .../include/Compiler/OperationRegistry.h | 8 +++---- TensorFrost/include/Compiler/TFProgram.h | 23 +++++++++++++++++++ .../src/Compiler/OperationArguments.cpp | 20 ++++++++++++++++ 6 files changed, 52 insertions(+), 8 deletions(-) diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 87d4c228..ef11eed6 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -151,7 +151,7 @@ PYBIND11_MODULE(TensorFrost, m) { }); return std::make_pair(inputs, outputs); }); - + program.Compile(); py::print(program.DebugPrint()); } diff --git 
a/TensorFrost/include/Compiler/Common.h b/TensorFrost/include/Compiler/Common.h index 62c64299..141967ad 100644 --- a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -11,8 +11,6 @@ #include #include #include -#include -#include namespace TensorFrost { extern "C" { @@ -104,7 +102,7 @@ struct Argument; using Attribute = std::variant; using AttributeMap = std::unordered_map; -using AttributeSpan = std::span; +using AttributeVector = std::vector; //ostringstream conversion for Attribute inline std::ostream& operator<<(std::ostream& os, const Attribute& attr) { diff --git a/TensorFrost/include/Compiler/OperationArguments.h b/TensorFrost/include/Compiler/OperationArguments.h index 66083aaf..ae2805c8 100644 --- a/TensorFrost/include/Compiler/OperationArguments.h +++ b/TensorFrost/include/Compiler/OperationArguments.h @@ -32,6 +32,9 @@ struct ArgumentManager { void SetAsOutput(Argument *arg); void RemoveOutput(Argument *arg); void SetArguments(ArgType type, std::vector args); + void Remove(ArgType type, int index); + void RemoveType(ArgType type); + void RemoveAll(); Arguments* Get(ArgType type) const; Arguments* operator[](ArgType type) const; diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h index 55ebd860..5822365c 100644 --- a/TensorFrost/include/Compiler/OperationRegistry.h +++ b/TensorFrost/include/Compiler/OperationRegistry.h @@ -29,7 +29,7 @@ enum class OpProp { Set, }; -using FoldFn = std::function; +using FoldFn = std::function; [[noreturn]] inline void bad_arity(std::size_t expect, std::size_t got) { @@ -39,7 +39,7 @@ using FoldFn = std::function; template FoldFn make_fold1(F f) { - return [f = std::move(f)](AttributeSpan a) -> Attribute { + return [f = std::move(f)](AttributeVector a) -> Attribute { if (a.size() != 1) bad_arity(1, a.size()); return std::visit([&](auto &&x) -> Attribute { return f(std::forward(x)); @@ -49,7 +49,7 @@ FoldFn make_fold1(F f) { 
template FoldFn make_fold2(F f) { - return [f = std::move(f)](AttributeSpan a) -> Attribute { + return [f = std::move(f)](AttributeVector a) -> Attribute { if (a.size() != 2) bad_arity(2, a.size()); return std::visit([&](auto &&x, auto &&y) -> Attribute { return f(std::forward(x), std::forward(y)); @@ -59,7 +59,7 @@ FoldFn make_fold2(F f) { template FoldFn make_fold3(F f) { - return [f = std::move(f)](AttributeSpan a) -> Attribute { + return [f = std::move(f)](AttributeVector a) -> Attribute { if (a.size() != 3) bad_arity(3, a.size()); return std::visit([&](auto &&x, auto &&y, auto &&z) -> Attribute { return f(std::forward(x), std::forward(y), std::forward(z)); diff --git a/TensorFrost/include/Compiler/TFProgram.h b/TensorFrost/include/Compiler/TFProgram.h index 750244dd..783244e4 100644 --- a/TensorFrost/include/Compiler/TFProgram.h +++ b/TensorFrost/include/Compiler/TFProgram.h @@ -11,6 +11,29 @@ class TFProgram { std::vector program_outputs; TFProgram(std::function, std::vector>()> program_fn); + + void Compile() { + StartExecutionContext(&context); + //constant folding + ApplyOpTransform(*GetBaseBlock(), [](Op& op) { + OpSpec* spec = GetOpSpec(op.opcode); + if(spec->constant_fold) { // Specification has a constant folding function + AttributeVector inputs; + for (const auto& arg : op.args->Get(ArgType::Input)->inputs) { + if(!arg->from->attributes.contains("value")) { + return; // Skip if some argument does not have a constant value + } + inputs.push_back(arg->from->attributes.at("value")); + } + Attribute result = spec->constant_fold(inputs); + op.attributes["value"] = result; // Set the result as a constant value + op.opcode = "const"; // Change the opcode to constant + op.args->RemoveAll(); // Clear all arguments + } + }); + EndExecutionContext(); + } + std::string DebugPrint() const; }; } \ No newline at end of file diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp index 049334ba..7d1711cc 100644 
--- a/TensorFrost/src/Compiler/OperationArguments.cpp +++ b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -71,6 +71,26 @@ void ArgumentManager::SetArguments(ArgType type, std::vector args) { } } +void ArgumentManager::Remove(ArgType type, int index) { + if (index < 0 || index >= type_args[(int)type]->inputs.size()) { + throw std::out_of_range("Index out of range for argument type " + ToString(type)); + } + type_args[(int)type]->RemoveInput(index); +} + +void ArgumentManager::RemoveType(ArgType type) { + auto& args = type_args[(int)type]; + for (size_t i = 0; i < args->inputs.size(); ++i) { + args->RemoveInput((int)i); + } +} + +void ArgumentManager::RemoveAll() { + for (int i = 0; i < (int)ArgType::Count; ++i) { + RemoveType((ArgType)i); + } +} + Arguments* ArgumentManager::Get(ArgType type) const { return type_args[(int)type].get(); } From ed1679a3a28bed11a4a505943e92ebeb9e4fe7d3 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Tue, 10 Jun 2025 05:52:30 +0200 Subject: [PATCH 11/44] Verbose print --- TensorFrost/src/Compiler/Printer.cpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp index e3e28400..fb87a357 100644 --- a/TensorFrost/src/Compiler/Printer.cpp +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -38,6 +38,20 @@ std::string PrintArguments(const auto_vector>& vec, st return PrintArray(StringifyArguments(vec), begin, end); } +std::string PrintArguments(const Arguments* args) { + if (!args) return ""; + std::vector inputs = StringifyArguments(args->inputs); + std::vector outputs; + for (const auto& arg : args->used_at) { + if (arg.second->to) { + outputs.push_back(VariableName(arg.second->to)); + } + } + std::string inputs_str = PrintArray(inputs, "inputs={", "}"); + std::string outputs_str = PrintArray(outputs, "outputs={", "}"); + return "[" + inputs_str + ", " + outputs_str + "]"; +} + std::string PrintAttribute(Attribute attr) { 
std::ostringstream oss; std::visit([&oss](const auto& v) { oss << v; }, attr); @@ -49,8 +63,10 @@ void PrintOp(const Op* op, std::ostringstream &os) { if (op->opcode == "const") { os << " = " << op->attributes.at("value"); } else { - std::string inputs = PrintArguments(op->args->Get(ArgType::Input)->inputs, "", ""); - std::string index = PrintArguments(op->args->Get(ArgType::Index)->inputs, "index={", "}"); + // std::string inputs = PrintArguments(op->args->Get(ArgType::Input)->inputs, "", ""); + // std::string index = PrintArguments(op->args->Get(ArgType::Index)->inputs, "index={", "}"); + std::string inputs = "args=" + PrintArguments(op->args->Get(ArgType::Input)); + std::string index = "index=" + PrintArguments(op->args->Get(ArgType::Index)); std::vector attributes; for (const auto& [key, value] : op->attributes) { attributes.push_back(key + ": " + PrintAttribute(value)); From 8254078f5f0730b7e2de4f80518f4562c786d032 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Tue, 10 Jun 2025 21:42:01 +0200 Subject: [PATCH 12/44] Make value wrapper --- TensorFrost/include/Compiler/Common.h | 1 + TensorFrost/include/Compiler/Operation.h | 3 -- TensorFrost/include/Compiler/Overloads.h | 63 ++++++---------------- TensorFrost/include/Compiler/TFProgram.h | 24 ++------- TensorFrost/include/Compiler/Value.h | 24 +++++++++ TensorFrost/include/TensorFrost.h | 1 + TensorFrost/src/Compiler/Overloads.cpp | 66 ++++++++++++++---------- TensorFrost/src/Compiler/TFProgram.cpp | 46 +++++++++++++++++ TensorFrost/src/Compiler/Value.cpp | 27 ++++++++++ 9 files changed, 155 insertions(+), 100 deletions(-) create mode 100644 TensorFrost/include/Compiler/Value.h create mode 100644 TensorFrost/src/Compiler/Value.cpp diff --git a/TensorFrost/include/Compiler/Common.h b/TensorFrost/include/Compiler/Common.h index 141967ad..ce292e3b 100644 --- a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -99,6 +99,7 @@ struct OpBlock; class OpBlockIterator; struct 
ArgumentManager; struct Argument; +class Value; using Attribute = std::variant; using AttributeMap = std::unordered_map; diff --git a/TensorFrost/include/Compiler/Operation.h b/TensorFrost/include/Compiler/Operation.h index 42824d61..934b40a0 100644 --- a/TensorFrost/include/Compiler/Operation.h +++ b/TensorFrost/include/Compiler/Operation.h @@ -24,9 +24,6 @@ struct Op { OpBlock* NewBlock(); OpBlock& GetBlock(int index = 0); - Op& operator[](int index); - Op& operator[](std::vector indices); - void AddAttribute(const std::string& name, const Attribute& value); void ChangeAttribute(const std::string& name, const Attribute& value); void GetAttribute(const std::string& name, Attribute& value) const; diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h index 38c2624a..967eb131 100644 --- a/TensorFrost/include/Compiler/Overloads.h +++ b/TensorFrost/include/Compiler/Overloads.h @@ -1,5 +1,6 @@ #pragma once #include "Operation.h" +#include "Value.h" namespace TensorFrost { Op& make_op(std::string op, std::vector ids, std::vector args); @@ -10,54 +11,20 @@ Op& constant(uint value); Op& constant(float value); Op& constant(bool value); -template concept Num = std::is_arithmetic_v>; -template concept IsOp = std::same_as, Op>; - -template -inline Op* as_op(T v) { - using D = std::remove_cvref_t; - using Target = - std::conditional_t, bool, - std::conditional_t, float, - std::conditional_t, unsigned int, - int>>>; - return &constant(static_cast(v)); +#define UNARY_OPERATOR(op_, opname_) inline Value operator op_(const Value& x) { \ + return Value(&func_op(opname_, {x.op})); \ } -inline Op* as_op(const Op& x) { return &const_cast(x); } - -#define UNARY_OPERATOR(op, opname) \ -template \ -requires IsOp \ -inline Op& operator op(const T& x) { \ - return func_op(opname, {as_op(x)}); \ +#define BINARY_OPERATOR(op_, opname_) inline Value operator op_(const Value& x, const Value& y) { \ + return Value(&func_op(opname_, {x.op, y.op})); \ } - 
-#define BINARY_OPERATOR(op, opname) \ -template \ -requires (IsOp || IsOp) \ -inline Op& operator op(const T& x, const U& y) { \ - return func_op(opname, {as_op(x), as_op(y)}); \ +#define UNARY_FUNCTION(func_, opname_) inline Value func_(const Value& x) { \ + return Value(&func_op(opname_, {x.op})); \ } - -#define UNARY_FUNCTION(func, opname) \ -template \ -requires IsOp \ -inline Op& func(const T& x) { \ - return func_op(opname, {as_op(x)}); \ +#define BINARY_FUNCTION(func_, opname_) inline Value func_(const Value& x, const Value& y) { \ + return Value(&func_op(opname_, {x.op, y.op})); \ } - -#define BINARY_FUNCTION(func, opname) \ -template \ -requires (IsOp || IsOp) \ -inline Op& func(const T& x, const U& y) { \ - return func_op(opname, {as_op(x), as_op(y)}); \ -} - -#define TERNARY_FUNCTION(func, opname) \ -template \ -requires (IsOp || IsOp || IsOp) \ -inline Op& func(const T& x, const U& y, const V& z) { \ - return func_op(opname, {as_op(x), as_op(y), as_op(z)}); \ +#define TERNARY_FUNCTION(func_, opname_) inline Value func_(const Value& x, const Value& y, const Value& z) { \ + return Value(&func_op(opname_, {x.op, y.op, z.op})); \ } UNARY_OPERATOR(+, "pos") @@ -137,8 +104,8 @@ TERNARY_FUNCTION(smoothstep, "smoothstep") TERNARY_FUNCTION(select, "ternary") TERNARY_FUNCTION(fma, "fma") -Op& unpack_tuple(const Op& x, int index = 0); -Op& vmap(std::vector shape, std::function body); -Op& memory(std::vector shape, TFDataFormat type); -Op& load_at_index(const Op& mem, std::vector indices); +Value unpack_tuple(Value x, int index = 0); +Value vmap(std::vector shape, std::function body); +Value memory(std::vector shape, TFDataFormat type); +Value load_at_index(Value mem, std::vector indices); } diff --git a/TensorFrost/include/Compiler/TFProgram.h b/TensorFrost/include/Compiler/TFProgram.h index 783244e4..ce2a62d4 100644 --- a/TensorFrost/include/Compiler/TFProgram.h +++ b/TensorFrost/include/Compiler/TFProgram.h @@ -12,27 +12,9 @@ class TFProgram { 
TFProgram(std::function, std::vector>()> program_fn); - void Compile() { - StartExecutionContext(&context); - //constant folding - ApplyOpTransform(*GetBaseBlock(), [](Op& op) { - OpSpec* spec = GetOpSpec(op.opcode); - if(spec->constant_fold) { // Specification has a constant folding function - AttributeVector inputs; - for (const auto& arg : op.args->Get(ArgType::Input)->inputs) { - if(!arg->from->attributes.contains("value")) { - return; // Skip if some argument does not have a constant value - } - inputs.push_back(arg->from->attributes.at("value")); - } - Attribute result = spec->constant_fold(inputs); - op.attributes["value"] = result; // Set the result as a constant value - op.opcode = "const"; // Change the opcode to constant - op.args->RemoveAll(); // Clear all arguments - } - }); - EndExecutionContext(); - } + void Compile(); + void ConstantFold(); + void RemoveUnused(); std::string DebugPrint() const; }; diff --git a/TensorFrost/include/Compiler/Value.h b/TensorFrost/include/Compiler/Value.h new file mode 100644 index 00000000..d112f727 --- /dev/null +++ b/TensorFrost/include/Compiler/Value.h @@ -0,0 +1,24 @@ +#pragma once +#include "Operation.h" + +namespace TensorFrost { + +// Op wrapper class for overloaded mathematics and operations +class Value { +public: + Op* op = nullptr; + + Value(Value& other) : op(other.op) {} + Value(Op* operation); + Value(float value); + Value(int value); + Value(uint value); + Value(bool value); + + Value operator[](int index); + Value operator[](std::vector indices); +}; + +std::vector values_to_ops(const std::vector& values); +std::vector ops_to_values(const std::vector& ops); +} \ No newline at end of file diff --git a/TensorFrost/include/TensorFrost.h b/TensorFrost/include/TensorFrost.h index 8ce5653f..13528d8d 100644 --- a/TensorFrost/include/TensorFrost.h +++ b/TensorFrost/include/TensorFrost.h @@ -5,5 +5,6 @@ #include "Compiler/OperationBlocks.h" #include "Compiler/OperationArguments.h" #include "Compiler/Overloads.h" 
+#include "Compiler/Value.h" #include "Compiler/Printer.h" #include "Compiler/TFProgram.h" \ No newline at end of file diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index d83bdaab..c9abfcef 100644 --- a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -56,37 +56,47 @@ Op& constant(bool value) { const_op.type = TFTypeBool32; return const_op; } - -Op& unpack_tuple(const Op &x, int index) { - Op& elem = func_op("unpack_tuple_int", {as_op(x)}); +// +// Op& unpack_tuple(const Op &x, int index) { +// Op& elem = func_op("unpack_tuple_int", {as_op(x)}); +// elem.attributes["index"] = index; // Default index +// return elem; +// } +// +// Op& vmap(std::vector shape, std::function body) { +// Op& par_op = func_op("vmap", std::move(shape)); +// GetContext()->BeginCursor(par_op.GetBlock().begin()); +// body(par_op); +// GetContext()->EndCursor(); +// return par_op; +// } +// +// Op& memory(std::vector shape, TFDataFormat type) { +// Op& mem_op = func_op("memory", std::move(shape)); +// mem_op.type = type; +// return mem_op; +// } +// +// Op& load_at_index(const Op& mem, std::vector indices) { +// return make_op("load", indices, {as_op(mem)}); +// } + +Value unpack_tuple(Value x, int index) { + if (x.op->type.type != TFType::Tuple) { + throw std::runtime_error("Cannot unpack non-tuple value"); + } + Op& elem = func_op("unpack_tuple_int", {x.op}); elem.attributes["index"] = index; // Default index - return elem; + return Value(&elem); } -Op& vmap(std::vector shape, std::function body) { - Op& par_op = func_op("vmap", std::move(shape)); +Value vmap(std::vector shape, std::function body) { + Op& par_op = func_op("vmap", {}); + for (const auto& dim : shape) { + par_op.args->AddArgument(ArgType::Input, dim.op); + } GetContext()->BeginCursor(par_op.GetBlock().begin()); - body(par_op); + body(Value(&par_op)); GetContext()->EndCursor(); - return par_op; -} - -Op& memory(std::vector shape, TFDataFormat 
type) { - Op& mem_op = func_op("memory", std::move(shape)); - mem_op.type = type; - return mem_op; -} - -Op& load_at_index(const Op& mem, std::vector indices) { - return make_op("load", indices, {as_op(mem)}); -} - -Op& Op::operator[](int index) { - return unpack_tuple(*this, index); -} - -Op& Op::operator[](std::vector indices) { - return load_at_index(*this, std::move(indices)); -} - + return Value(&par_op); } diff --git a/TensorFrost/src/Compiler/TFProgram.cpp b/TensorFrost/src/Compiler/TFProgram.cpp index f99aa533..9ef9e96f 100644 --- a/TensorFrost/src/Compiler/TFProgram.cpp +++ b/TensorFrost/src/Compiler/TFProgram.cpp @@ -15,6 +15,52 @@ TFProgram::TFProgram(std::function, std::vector> EndExecutionContext(); } +void TFProgram::Compile() { + StartExecutionContext(&context); + ConstantFold(); + EndExecutionContext(); +} + +void TFProgram::ConstantFold() { + ApplyOpTransform(*GetBaseBlock(), [](Op& op) { + OpSpec* spec = GetOpSpec(op.opcode); + if(!spec->constant_fold) return; // Skip if no constant folding is defined for this operation + AttributeVector inputs; + for (const auto& arg : op.args->Get(ArgType::Input)->inputs) { + if(!arg->from->attributes.contains("value")) { + return; // Skip if some argument does not have a constant value + } + inputs.push_back(arg->from->attributes.at("value")); + } + Attribute result = spec->constant_fold(inputs); + op.attributes["value"] = result; // Set the result as a constant value + op.opcode = "const"; // Change the opcode to constant + op.args->RemoveAll(); // Clear all arguments + }); +} + +void FindUsedOps(std::set& used_ops, OpBlock& block) { + for (auto& op : block.ops) { + if (op->used_at.empty()) continue; // Skip unused operations + used_ops.insert(op.get()); + for (auto& sub_block : op->blocks) { + FindUsedOps(used_ops, *sub_block); + } + } +} + +void TFProgram::RemoveUnused() { + StartExecutionContext(&context); + + ApplyOpTransform(*GetBaseBlock(), [](Op& op) { + // if (op.opcode == "const") return; // Skip 
constants + // if (op.args->Get(ArgType::Input)->inputs.empty()) return; // Skip operations with no inputs + // if (op.args->Get(ArgType::Output)->inputs.empty()) return; // Skip operations with no outputs + // if (op.used_at.empty()) return; // Skip unused operations + }); + EndExecutionContext(); +} + std::string TFProgram::DebugPrint() const { std::string program_header = "TFProgram(inputs=" + PrintArray(TransformVector(program_inputs, VariableName), "[", "]") + ") {\n"; std::string inner_code = PrintBlock(*context.base_block); diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/src/Compiler/Value.cpp new file mode 100644 index 00000000..2419956c --- /dev/null +++ b/TensorFrost/src/Compiler/Value.cpp @@ -0,0 +1,27 @@ +#include "Compiler/Operation.h" +#include "Compiler/ExecutionContext.h" +#include "Compiler/Value.h" +using namespace std; + +namespace TensorFrost { + +Value::Value(Op* operation) : op(operation) { + if (!op) { + throw std::runtime_error("Value cannot be constructed with a null Op pointer"); + } +} + +Value::Value(float value) : op(&constant(value)) {} +Value::Value(int value) : op(&constant(value)) {} +Value::Value(uint value) : op(&constant(value)) {} +Value::Value(bool value) : op(&constant(value)) {} + +Value Value::operator[](int index) { + return unpack_tuple(*this, index); +} + +Op& Op::operator[](std::vector indices) { + return load_at_index(*this, std::move(indices)); +} + +} \ No newline at end of file From 36fc4d5d58724bddd01a21730b272a014ea83657 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Tue, 10 Jun 2025 22:49:10 +0200 Subject: [PATCH 13/44] Fix value wrapper --- TensorFrost/PybindModule.cpp | 26 ++-- TensorFrost/include/Compiler/Common.h | 4 +- TensorFrost/include/Compiler/Overloads.h | 111 ++---------------- TensorFrost/include/Compiler/TFProgram.h | 6 +- TensorFrost/include/Compiler/Value.h | 30 ++++- TensorFrost/src/Compiler/Common.cpp | 13 ++ .../src/Compiler/OperationRegistry.cpp | 8 +- 
TensorFrost/src/Compiler/Overloads.cpp | 102 ++++++---------- TensorFrost/src/Compiler/TFProgram.cpp | 24 ++-- TensorFrost/src/Compiler/Value.cpp | 106 +++++++++++++++-- 10 files changed, 223 insertions(+), 207 deletions(-) diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index ef11eed6..c0d9c062 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -134,19 +134,19 @@ PYBIND11_MODULE(TensorFrost, m) { // TEST CODE TFProgram program([]() -> auto { - std::vector inputs; - std::vector outputs; - Op& a = constant(5); - Op& b = constant(10); - Op& c = a + b * 3; - Op& mem = memory({&a, &b, &c}, TFTypeFloat32); - inputs.push_back(&mem); - vmap({&a, &b, &c}, [&](Op& ids0) { - Op& imem = toint(mem); - Op& d = c + b + ids0[0] * imem; - vmap({&c, &c}, [&](Op& ids1) { - Op& m = d * c * imem[{&ids1[1], &ids1[0], &ids0[0]}]; - outputs.push_back(&m); + std::vector inputs; + std::vector outputs; + Value a = 5; + Value b = 10; + Value c = a + b * 3; + Value mem = memory({a, b, c}, TFTypeFloat32); + inputs.push_back(mem); + vmap({a, b, c}, [&](Value ids0) { + Value imem = toint(mem); + Value d = c + b + ids0[0] * imem; + vmap({c, c}, [&](Value ids1) { + Value m = d * c * imem[{ids1[1], ids1[0], ids0[0]}]; + outputs.push_back(m); }); }); return std::make_pair(inputs, outputs); diff --git a/TensorFrost/include/Compiler/Common.h b/TensorFrost/include/Compiler/Common.h index ce292e3b..0f01e8e3 100644 --- a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -101,10 +101,12 @@ struct ArgumentManager; struct Argument; class Value; -using Attribute = std::variant; +using Attribute = std::variant; using AttributeMap = std::unordered_map; using AttributeVector = std::vector; +TFDataFormat GetTypeFromAttribute(const Attribute& attr); + //ostringstream conversion for Attribute inline std::ostream& operator<<(std::ostream& os, const Attribute& attr) { std::visit([&os](const auto& v) { os << v; }, attr); diff 
--git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h index 967eb131..1c7d640f 100644 --- a/TensorFrost/include/Compiler/Overloads.h +++ b/TensorFrost/include/Compiler/Overloads.h @@ -3,109 +3,20 @@ #include "Value.h" namespace TensorFrost { -Op& make_op(std::string op, std::vector ids, std::vector args); -Op& func_op(const std::string& name, std::vector args = {}); - -Op& constant(int value); -Op& constant(uint value); -Op& constant(float value); -Op& constant(bool value); - -#define UNARY_OPERATOR(op_, opname_) inline Value operator op_(const Value& x) { \ - return Value(&func_op(opname_, {x.op})); \ -} -#define BINARY_OPERATOR(op_, opname_) inline Value operator op_(const Value& x, const Value& y) { \ - return Value(&func_op(opname_, {x.op, y.op})); \ -} -#define UNARY_FUNCTION(func_, opname_) inline Value func_(const Value& x) { \ - return Value(&func_op(opname_, {x.op})); \ -} -#define BINARY_FUNCTION(func_, opname_) inline Value func_(const Value& x, const Value& y) { \ - return Value(&func_op(opname_, {x.op, y.op})); \ -} -#define TERNARY_FUNCTION(func_, opname_) inline Value func_(const Value& x, const Value& y, const Value& z) { \ - return Value(&func_op(opname_, {x.op, y.op, z.op})); \ -} - -UNARY_OPERATOR(+, "pos") -UNARY_OPERATOR(-, "neg") -UNARY_OPERATOR(~, "not") -UNARY_OPERATOR(!, "lnot") - -BINARY_OPERATOR(+, "add") -BINARY_OPERATOR(-, "sub") -BINARY_OPERATOR(*, "mul") -BINARY_OPERATOR(/, "div") -BINARY_OPERATOR(%, "mod") -BINARY_OPERATOR(&, "and") -BINARY_OPERATOR(|, "or") -BINARY_OPERATOR(^, "xor") -BINARY_OPERATOR(<<, "lshift") -BINARY_OPERATOR(>>, "rshift") -BINARY_OPERATOR(==, "eq") -BINARY_OPERATOR(!=, "neq") -BINARY_OPERATOR(<, "lt") -BINARY_OPERATOR(<=, "lte") -BINARY_OPERATOR(>, "gt") -BINARY_OPERATOR(>=, "gte") -BINARY_OPERATOR(&&, "land") -BINARY_OPERATOR(||, "lor") - -UNARY_FUNCTION(copy, "copy") -UNARY_FUNCTION(sin, "sin") -UNARY_FUNCTION(cos, "cos") -UNARY_FUNCTION(tan, "tan") -UNARY_FUNCTION(asin, 
"asin") -UNARY_FUNCTION(acos, "acos") -UNARY_FUNCTION(atan, "atan") -UNARY_FUNCTION(sinh, "sinh") -UNARY_FUNCTION(cosh, "cosh") -UNARY_FUNCTION(tanh, "tanh") -UNARY_FUNCTION(asinh, "asinh") -UNARY_FUNCTION(acosh, "acosh") -UNARY_FUNCTION(atanh, "atanh") -UNARY_FUNCTION(exp, "exp") -UNARY_FUNCTION(log, "log") -UNARY_FUNCTION(log2, "log2") -UNARY_FUNCTION(exp2, "exp2") -UNARY_FUNCTION(sqrt, "sqrt") -UNARY_FUNCTION(sqr, "sqr") -UNARY_FUNCTION(rsqrt, "rsqrt") -UNARY_FUNCTION(rcp, "rcp") -UNARY_FUNCTION(abs, "abs") -UNARY_FUNCTION(sign, "sign") -UNARY_FUNCTION(floor, "floor") -UNARY_FUNCTION(ceil, "ceil") -UNARY_FUNCTION(round, "round") -UNARY_FUNCTION(trunc, "trunc") -UNARY_FUNCTION(frac, "frac") -UNARY_FUNCTION(pcg, "pcg") -UNARY_FUNCTION(pcgf, "pcgf") -UNARY_FUNCTION(reversebits, "reversebits") -UNARY_FUNCTION(tofloat, "tofloat") -UNARY_FUNCTION(toint, "toint") -UNARY_FUNCTION(touint, "touint") -UNARY_FUNCTION(tobool, "tobool") -UNARY_FUNCTION(asfloat, "asfloat") -UNARY_FUNCTION(asint, "asint") -UNARY_FUNCTION(asuint, "asuint") -UNARY_FUNCTION(clamp, "clamp") - -BINARY_FUNCTION(pow, "pow") -BINARY_FUNCTION(min, "min") -BINARY_FUNCTION(max, "max") -BINARY_FUNCTION(mod, "mod") -BINARY_FUNCTION(modf, "modf") -BINARY_FUNCTION(atan2, "atan2") -BINARY_FUNCTION(grad, "backwards_grad") - -TERNARY_FUNCTION(lerp, "lerp") -TERNARY_FUNCTION(smoothstep, "smoothstep") -TERNARY_FUNCTION(select, "ternary") -TERNARY_FUNCTION(fma, "fma") +Value make_op(std::string op, std::vector ids = {}, std::vector args = {}); +Value func_op(const std::string &name, std::vector args = {}); +Value constant(int value); +Value constant(uint value); +Value constant(float value); +Value constant(bool value); Value unpack_tuple(Value x, int index = 0); Value vmap(std::vector shape, std::function body); Value memory(std::vector shape, TFDataFormat type); Value load_at_index(Value mem, std::vector indices); + +inline Value toint(Value x) { return func_op("toint", {x}); } +inline Value tofloat(Value x) { 
return func_op("tofloat", {x}); } +inline Value touint(Value x) { return func_op("touint", {x}); } +inline Value tobool(Value x) { return func_op("tobool", {x}); } } diff --git a/TensorFrost/include/Compiler/TFProgram.h b/TensorFrost/include/Compiler/TFProgram.h index ce2a62d4..c55ffb75 100644 --- a/TensorFrost/include/Compiler/TFProgram.h +++ b/TensorFrost/include/Compiler/TFProgram.h @@ -7,10 +7,10 @@ namespace TensorFrost { class TFProgram { public: ExecutionContext context; - std::vector program_inputs; - std::vector program_outputs; + std::vector program_inputs; + std::vector program_outputs; - TFProgram(std::function, std::vector>()> program_fn); + TFProgram(std::function, std::vector>()> program_fn); void Compile(); void ConstantFold(); diff --git a/TensorFrost/include/Compiler/Value.h b/TensorFrost/include/Compiler/Value.h index d112f727..4b35ee01 100644 --- a/TensorFrost/include/Compiler/Value.h +++ b/TensorFrost/include/Compiler/Value.h @@ -8,15 +8,39 @@ class Value { public: Op* op = nullptr; - Value(Value& other) : op(other.op) {} Value(Op* operation); Value(float value); Value(int value); Value(uint value); Value(bool value); + Value(const Value& other) : op(other.op) {} - Value operator[](int index); - Value operator[](std::vector indices); + // indexed access + Value operator[](int index) const; + Value operator[](const std::vector& indices) const; + + // binary ops take const ref and are const themselves + Value operator+(const Value& other) const; + Value operator-(const Value& other) const; + Value operator*(const Value& other) const; + Value operator/(const Value& other) const; + Value operator%(const Value& other) const; + Value operator==(const Value& other) const; + Value operator!=(const Value& other) const; + Value operator<(const Value& other) const; + Value operator<=(const Value& other) const; + Value operator>(const Value& other) const; + Value operator>=(const Value& other) const; + Value operator&&(const Value& other) const; + Value 
operator||(const Value& other) const; + Value operator<<(const Value& other) const; + Value operator>>(const Value& other) const; + + // unary ops + Value operator!() const; + Value operator-() const; + Value operator+() const; + Value operator~() const; }; std::vector values_to_ops(const std::vector& values); diff --git a/TensorFrost/src/Compiler/Common.cpp b/TensorFrost/src/Compiler/Common.cpp index b78a0a84..f16071ed 100644 --- a/TensorFrost/src/Compiler/Common.cpp +++ b/TensorFrost/src/Compiler/Common.cpp @@ -20,4 +20,17 @@ bool TFDataFormat::operator<(const TFDataFormat &other) const { bool TFDataFormat::operator>(const TFDataFormat &other) const { return GetHash() > other.GetHash(); } + +TFDataFormat GetTypeFromAttribute(const Attribute& attr) { + if (std::holds_alternative(attr)) { + return TFTypeInt32; + } else if (std::holds_alternative(attr)) { + return TFTypeUint32; + } else if (std::holds_alternative(attr)) { + return TFTypeFloat32; + } else if (std::holds_alternative(attr)) { + return TFTypeBool32; + } + throw std::runtime_error("Unsupported attribute type for TFDataFormat conversion"); +} } diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index c4236af3..37b26ca2 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -67,7 +67,13 @@ OverloadsMap GenerateOverloadMap(const std::string& input) { OpSpec{ .name = op_name, .overloads = GenerateOverloadMap(overload_str), .op_class = operation_class, __VA_ARGS__ } #define BIN_OP_FOLD(op) \ - make_fold2([](auto a, auto b) { return a op b; }) +make_fold2([](auto a, auto b) { \ + if constexpr (std::is_same_v || std::is_same_v) { \ + return static_cast(a) op static_cast(b); \ + } else { \ + return a op b; \ + } \ +}) #define UN_OP_FOLD(op) \ make_fold1([](auto a) { return op a; }) diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index 
c9abfcef..9fd0fac2 100644 --- a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -1,102 +1,74 @@ #include "Compiler/Operation.h" #include "Compiler/ExecutionContext.h" +#include "Compiler/Value.h" using namespace std; namespace TensorFrost { // General function to create an Op instance in the current execution context -Op& make_op(std::string op, std::vector ids, std::vector args) { +Value make_op(std::string op, std::vector ids, std::vector args) { OpSpec* spec = GetOpSpec(op); vector arg_types; for (const auto& arg : args) { - arg_types.push_back(arg->type); + arg_types.push_back(arg.op->type); } TFDataFormat output_type = spec->GetOutputType(arg_types); Op* op_instance = new Op(op); op_instance->type = output_type; - op_instance->args->SetArguments(ArgType::Index, ids); - op_instance->args->SetArguments(ArgType::Input, args); + op_instance->args->SetArguments(ArgType::Index, values_to_ops(ids)); + op_instance->args->SetArguments(ArgType::Input, values_to_ops(args)); // Create blocks for (int i = 0; i < spec->blocks; ++i) { op_instance->NewBlock(); } - return GetContext()->Add(std::unique_ptr(op_instance)); + return Value(&GetContext()->Add(std::unique_ptr(op_instance))); } -Op & func_op(const std::string &name, std::vector args) { +Value func_op(const std::string &name, std::vector args) { return make_op(name, {}, std::move(args)); } -Op& constant(int value) { - Op& const_op = func_op("const"); - const_op.attributes["value"] = value; - const_op.type = TFTypeInt32; +Value constant(Attribute value) { + Value const_op = func_op("const"); + const_op.op->attributes["value"] = value; + const_op.op->type = GetTypeFromAttribute(value); return const_op; } -Op& constant(uint value) { - Op& const_op = func_op("const"); - const_op.attributes["value"] = value; - const_op.type = TFTypeUint32; - return const_op; -} - -Op& constant(float value) { - Op& const_op = func_op("const"); - const_op.attributes["value"] = value; - const_op.type = 
TFTypeFloat32; - return const_op; -} - -Op& constant(bool value) { - Op& const_op = func_op("const"); - const_op.attributes["value"] = value; - const_op.type = TFTypeBool32; - return const_op; -} -// -// Op& unpack_tuple(const Op &x, int index) { -// Op& elem = func_op("unpack_tuple_int", {as_op(x)}); -// elem.attributes["index"] = index; // Default index -// return elem; -// } -// -// Op& vmap(std::vector shape, std::function body) { -// Op& par_op = func_op("vmap", std::move(shape)); -// GetContext()->BeginCursor(par_op.GetBlock().begin()); -// body(par_op); -// GetContext()->EndCursor(); -// return par_op; -// } -// -// Op& memory(std::vector shape, TFDataFormat type) { -// Op& mem_op = func_op("memory", std::move(shape)); -// mem_op.type = type; -// return mem_op; -// } -// -// Op& load_at_index(const Op& mem, std::vector indices) { -// return make_op("load", indices, {as_op(mem)}); -// } +Value constant(int value) { return constant(Attribute(value)); } +Value constant(uint value) { return constant(Attribute(value)); } +Value constant(float value) { return constant(Attribute(value)); } +Value constant(bool value) { return constant(Attribute(value)); } Value unpack_tuple(Value x, int index) { - if (x.op->type.type != TFType::Tuple) { + if (x.op->type != TFTypeTuple) { throw std::runtime_error("Cannot unpack non-tuple value"); } - Op& elem = func_op("unpack_tuple_int", {x.op}); - elem.attributes["index"] = index; // Default index - return Value(&elem); + Value elem = func_op("unpack_tuple_int", {x}); + elem.op->attributes["index"] = index; // Default index + return elem; } Value vmap(std::vector shape, std::function body) { - Op& par_op = func_op("vmap", {}); - for (const auto& dim : shape) { - par_op.args->AddArgument(ArgType::Input, dim.op); - } - GetContext()->BeginCursor(par_op.GetBlock().begin()); - body(Value(&par_op)); + Value par_op = func_op("vmap", shape); + GetContext()->BeginCursor(par_op.op->GetBlock().begin()); + body(par_op); 
GetContext()->EndCursor(); - return Value(&par_op); + return par_op; +} + +Value memory(std::vector shape, TFDataFormat type) { + Value mem_op = func_op("memory", std::move(shape)); + mem_op.op->type = type; + return mem_op; +} + +Value load_at_index(Value mem, std::vector indices) { + if (mem.op->type.type == TFType::None) { + throw std::runtime_error("Cannot load from a None type memory"); + } + return make_op("load", indices, {mem}); +} } diff --git a/TensorFrost/src/Compiler/TFProgram.cpp b/TensorFrost/src/Compiler/TFProgram.cpp index 9ef9e96f..8eed77e9 100644 --- a/TensorFrost/src/Compiler/TFProgram.cpp +++ b/TensorFrost/src/Compiler/TFProgram.cpp @@ -1,7 +1,7 @@ #include "Compiler/TFProgram.h" namespace TensorFrost { -TFProgram::TFProgram(std::function, std::vector>()> program_fn) { +TFProgram::TFProgram(std::function, std::vector>()> program_fn) { StartExecutionContext(&context); auto [ins, outs] = program_fn(); @@ -39,15 +39,15 @@ void TFProgram::ConstantFold() { }); } -void FindUsedOps(std::set& used_ops, OpBlock& block) { - for (auto& op : block.ops) { - if (op->used_at.empty()) continue; // Skip unused operations - used_ops.insert(op.get()); - for (auto& sub_block : op->blocks) { - FindUsedOps(used_ops, *sub_block); - } - } -} +// void FindUsedOps(std::set& used_ops, OpBlock& block) { +// for (auto& op : block.ops) { +// if (op->used_at.empty()) continue; // Skip unused operations +// used_ops.insert(op.get()); +// for (auto& sub_block : op->blocks) { +// FindUsedOps(used_ops, *sub_block); +// } +// } +// } void TFProgram::RemoveUnused() { StartExecutionContext(&context); @@ -62,9 +62,9 @@ void TFProgram::RemoveUnused() { } std::string TFProgram::DebugPrint() const { - std::string program_header = "TFProgram(inputs=" + PrintArray(TransformVector(program_inputs, VariableName), "[", "]") + ") {\n"; + std::string program_header = "TFProgram(inputs=" + PrintArray(TransformVector(values_to_ops(program_inputs), VariableName), "[", "]") + ") {\n"; std::string 
inner_code = PrintBlock(*context.base_block); - inner_code += "return " + PrintArray(TransformVector(program_outputs, VariableName), "[", "]") + ";\n"; + inner_code += "return " + PrintArray(TransformVector(values_to_ops(program_outputs), VariableName), "[", "]") + ";\n"; return program_header + AddIndent(inner_code, 2) + "}\n"; } } diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/src/Compiler/Value.cpp index 2419956c..f3c4cb24 100644 --- a/TensorFrost/src/Compiler/Value.cpp +++ b/TensorFrost/src/Compiler/Value.cpp @@ -11,17 +11,105 @@ Value::Value(Op* operation) : op(operation) { } } -Value::Value(float value) : op(&constant(value)) {} -Value::Value(int value) : op(&constant(value)) {} -Value::Value(uint value) : op(&constant(value)) {} -Value::Value(bool value) : op(&constant(value)) {} +Value::Value(float value) : op(constant(value).op) {} +Value::Value(int value) : op(constant(value).op) {} +Value::Value(uint value) : op(constant(value).op) {} +Value::Value(bool value) : op(constant(value).op) { -Value Value::operator[](int index) { - return unpack_tuple(*this, index); } -Op& Op::operator[](std::vector indices) { - return load_at_index(*this, std::move(indices)); +std::vector values_to_ops(const std::vector& values) { + std::vector ops; + ops.reserve(values.size()); + for (const auto& value : values) { + if (value.op) { + ops.push_back(value.op); + } else { + throw std::runtime_error("Value contains a null Op pointer"); + } + } + return ops; +} + +std::vector ops_to_values(const std::vector& ops) { + std::vector values; + values.reserve(ops.size()); + for (const auto& op : ops) { + if (op) { + values.emplace_back(op); + } else { + throw std::runtime_error("Op pointer in vector is null"); + } + } + return values; +} + +Value Value::operator+(const Value& other) const { + return func_op("add", {op, other.op}); +} +Value Value::operator-(const Value& other) const { + return func_op("sub", {op, other.op}); +} +Value Value::operator*(const Value& other) 
const { + return func_op("mul", {op, other.op}); +} +Value Value::operator/(const Value& other) const { + return func_op("div", {op, other.op}); +} +Value Value::operator%(const Value& other) const { + return func_op("mod", {op, other.op}); +} +Value Value::operator==(const Value& other) const { + return func_op("eq", {op, other.op}); +} +Value Value::operator!=(const Value& other) const { + return func_op("ne", {op, other.op}); +} +Value Value::operator<(const Value& other) const { + return func_op("lt", {op, other.op}); +} +Value Value::operator<=(const Value& other) const { + return func_op("le", {op, other.op}); +} +Value Value::operator>(const Value& other) const { + return func_op("gt", {op, other.op}); +} +Value Value::operator>=(const Value& other) const { + return func_op("ge", {op, other.op}); +} +Value Value::operator<<(const Value& other) const { + return func_op("shl", {op, other.op}); +} +Value Value::operator>>(const Value& other) const { + return func_op("shr", {op, other.op}); +} + +Value Value::operator&&(const Value& other) const { + return func_op("land", {op, other.op}); +} +Value Value::operator||(const Value& other) const { + return func_op("lor", {op, other.op}); +} +Value Value::operator!() const { + return func_op("lnot", {op}); +} + +Value Value::operator-() const { + return func_op("neg", {op}); +} +Value Value::operator+() const { + return func_op("pos", {op}); +} +Value Value::operator~() const { + return func_op("not", {op}); +} + +Value Value::operator[](int index) const { + return unpack_tuple(*this, index); } +Value Value::operator[](const std::vector& indices) const { + return load_at_index(*this, indices); +} + +} // namespace TensorFrost -} \ No newline at end of file From b38baf0e4f92ca2e6406f5bc6b16dcd8af0df2e9 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Wed, 11 Jun 2025 00:04:12 +0200 Subject: [PATCH 14/44] Unused operation removal --- Python/pyproject.toml | 2 +- TensorFrost/include/Compiler/Operation.h | 8 ++-- 
.../include/Compiler/OperationBlocks.h | 3 ++ TensorFrost/src/Compiler/ExecutionContext.cpp | 2 +- .../src/Compiler/OperationArguments.cpp | 1 + TensorFrost/src/Compiler/OperationBlocks.cpp | 44 +++++++++++++++++++ TensorFrost/src/Compiler/Printer.cpp | 8 ++-- TensorFrost/src/Compiler/TFProgram.cpp | 27 +++--------- TensorFrost/src/Compiler/Value.cpp | 4 +- 9 files changed, 66 insertions(+), 33 deletions(-) diff --git a/Python/pyproject.toml b/Python/pyproject.toml index 2d9c556b..55af72a3 100644 --- a/Python/pyproject.toml +++ b/Python/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "TensorFrost" -version = "0.8.0.dev0" +version = "2.0.0.dev0" description = "A static optimizing tensor compiler with a Python frontend" authors = [{name = "Mykhailo Moroz", email = "michael08840884@gmail.com"}] requires-python = ">=3.7" diff --git a/TensorFrost/include/Compiler/Operation.h b/TensorFrost/include/Compiler/Operation.h index 934b40a0..1224958b 100644 --- a/TensorFrost/include/Compiler/Operation.h +++ b/TensorFrost/include/Compiler/Operation.h @@ -10,16 +10,16 @@ namespace TensorFrost { struct Op { - OpBlock* parent_block = nullptr; - - size_t index = 0; //might not be up to date std::string opcode; - std::string varname; std::unique_ptr args; AttributeMap attributes; TFDataFormat type; std::vector> blocks; + OpBlock* parent_block = nullptr; + size_t index = 0; //might not be up to date + std::string varname; + Op(std::string op_name); OpBlock* NewBlock(); OpBlock& GetBlock(int index = 0); diff --git a/TensorFrost/include/Compiler/OperationBlocks.h b/TensorFrost/include/Compiler/OperationBlocks.h index ba76cf87..119af204 100644 --- a/TensorFrost/include/Compiler/OperationBlocks.h +++ b/TensorFrost/include/Compiler/OperationBlocks.h @@ -28,6 +28,7 @@ struct OpBlock { Iterator& prev(); Iterator& insert_after(std::unique_ptr op); Iterator& insert_before(std::unique_ptr op); + Iterator& remove(); OpBlock* parent() const { return 
parent_; } @@ -41,4 +42,6 @@ struct OpBlock { }; void ApplyOpTransform(OpBlock& block, const std::function& transform); +void IterateOver(OpBlock &block, const std::function &transform); +std::set GetDependencies(std::vector ops); } diff --git a/TensorFrost/src/Compiler/ExecutionContext.cpp b/TensorFrost/src/Compiler/ExecutionContext.cpp index a0688a07..15540cf1 100644 --- a/TensorFrost/src/Compiler/ExecutionContext.cpp +++ b/TensorFrost/src/Compiler/ExecutionContext.cpp @@ -20,7 +20,7 @@ void ExecutionContext::EndCursor() { Op& ExecutionContext::Add(std::unique_ptr op) { cursor.insert_before(std::move(op)); - Op* new_op = cursor.get_next(); + Op* new_op = *cursor; cursor.next(); // Move cursor to the new op return *new_op; } diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp index 7d1711cc..e8b60147 100644 --- a/TensorFrost/src/Compiler/OperationArguments.cpp +++ b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -83,6 +83,7 @@ void ArgumentManager::RemoveType(ArgType type) { for (size_t i = 0; i < args->inputs.size(); ++i) { args->RemoveInput((int)i); } + args->inputs.clear(); } void ArgumentManager::RemoveAll() { diff --git a/TensorFrost/src/Compiler/OperationBlocks.cpp b/TensorFrost/src/Compiler/OperationBlocks.cpp index a9617dde..7abd4b07 100644 --- a/TensorFrost/src/Compiler/OperationBlocks.cpp +++ b/TensorFrost/src/Compiler/OperationBlocks.cpp @@ -35,6 +35,13 @@ OpBlock::Iterator& OpBlock::Iterator::insert_before(std::unique_ptr op) { return *this; } +OpBlock::Iterator& OpBlock::Iterator::remove() { + if (cur_ == list_->end()) return *this; // Nothing to remove + cur_->get()->parent_block = nullptr; // Clear parent block reference + cur_ = list_->erase(cur_); // Remove and update iterator + return *this; +} + bool OpBlock::Iterator::valid() const { return cur_ != list_->end(); } bool OpBlock::Iterator::operator==(const Iterator &o) const { return cur_ == o.cur_; } bool 
OpBlock::Iterator::operator!=(const Iterator &o) const { return cur_ != o.cur_; } @@ -50,4 +57,41 @@ void ApplyOpTransform(OpBlock &block, const std::function &transform transform(*op); } } + +void IterateOver(OpBlock &block, const std::function &transform) { + for (OpBlock::Iterator it = block.begin(); it.valid(); it.next()) { + for (auto& sub_block : it->blocks) { + IterateOver(*sub_block, transform); + } + transform(it); + } +} + +void ReverseIterateOver(OpBlock &block, const std::function &transform) { + for (OpBlock::Iterator it = block.end(); it.valid(); it.prev()) { + for (auto& sub_block : it->blocks) { + ReverseIterateOver(*sub_block, transform); + } + transform(it); + } +} + +std::set GetDependencies(std::vector ops) { + std::set dependencies; + std::function collect_dependencies = [&](Op* op) { + if (op == nullptr || dependencies.contains(op)) return; // Already processed + dependencies.insert(op); + for (auto& input : op->args->Get(ArgType::Input)->inputs) { + collect_dependencies(input->from); + } + for (auto& input : op->args->Get(ArgType::Index)->inputs) { + collect_dependencies(input->from); + } + collect_dependencies(op->parent_block->parent_op); // Collect dependencies of the parent op + }; + for (Op* op : ops) { + collect_dependencies(op); + } + return dependencies; +} } diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp index fb87a357..93ad764d 100644 --- a/TensorFrost/src/Compiler/Printer.cpp +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -63,10 +63,10 @@ void PrintOp(const Op* op, std::ostringstream &os) { if (op->opcode == "const") { os << " = " << op->attributes.at("value"); } else { - // std::string inputs = PrintArguments(op->args->Get(ArgType::Input)->inputs, "", ""); - // std::string index = PrintArguments(op->args->Get(ArgType::Index)->inputs, "index={", "}"); - std::string inputs = "args=" + PrintArguments(op->args->Get(ArgType::Input)); - std::string index = "index=" + 
PrintArguments(op->args->Get(ArgType::Index)); + std::string inputs = PrintArguments(op->args->Get(ArgType::Input)->inputs, "", ""); + std::string index = PrintArguments(op->args->Get(ArgType::Index)->inputs, "index={", "}"); + // std::string inputs = "args=" + PrintArguments(op->args->Get(ArgType::Input)); + // std::string index = "index=" + PrintArguments(op->args->Get(ArgType::Index)); std::vector attributes; for (const auto& [key, value] : op->attributes) { attributes.push_back(key + ": " + PrintAttribute(value)); diff --git a/TensorFrost/src/Compiler/TFProgram.cpp b/TensorFrost/src/Compiler/TFProgram.cpp index 8eed77e9..4dc917c4 100644 --- a/TensorFrost/src/Compiler/TFProgram.cpp +++ b/TensorFrost/src/Compiler/TFProgram.cpp @@ -10,14 +10,14 @@ TFProgram::TFProgram(std::function, std::vector& used_ops, OpBlock& block) { -// for (auto& op : block.ops) { -// if (op->used_at.empty()) continue; // Skip unused operations -// used_ops.insert(op.get()); -// for (auto& sub_block : op->blocks) { -// FindUsedOps(used_ops, *sub_block); -// } -// } -// } - void TFProgram::RemoveUnused() { - StartExecutionContext(&context); - - ApplyOpTransform(*GetBaseBlock(), [](Op& op) { - // if (op.opcode == "const") return; // Skip constants - // if (op.args->Get(ArgType::Input)->inputs.empty()) return; // Skip operations with no inputs - // if (op.args->Get(ArgType::Output)->inputs.empty()) return; // Skip operations with no outputs - // if (op.used_at.empty()) return; // Skip unused operations + std::set used_ops = GetDependencies(values_to_ops(program_outputs)); + IterateOver(*GetBaseBlock(), [&](OpBlock::Iterator& it) { + if (!used_ops.contains(*it)) { + it.remove(); // Remove unused operations + } }); - EndExecutionContext(); } std::string TFProgram::DebugPrint() const { diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/src/Compiler/Value.cpp index f3c4cb24..89c04f50 100644 --- a/TensorFrost/src/Compiler/Value.cpp +++ b/TensorFrost/src/Compiler/Value.cpp @@ -14,9 +14,7 
@@ Value::Value(Op* operation) : op(operation) { Value::Value(float value) : op(constant(value).op) {} Value::Value(int value) : op(constant(value).op) {} Value::Value(uint value) : op(constant(value).op) {} -Value::Value(bool value) : op(constant(value).op) { - -} +Value::Value(bool value) : op(constant(value).op) {} std::vector values_to_ops(const std::vector& values) { std::vector ops; From b6627ccccbd343ac0d3df84cc465c7262ac59c4d Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Sun, 15 Jun 2025 18:54:01 +0200 Subject: [PATCH 15/44] Improved compiler --- .idea/editor.xml | 352 +++--------------- TensorFrost/CMakeLists.txt | 4 + TensorFrost/PybindModule.cpp | 16 +- TensorFrost/include/Compiler/Common.h | 1 + .../include/Compiler/ExecutionContext.h | 3 +- TensorFrost/include/Compiler/Operation.h | 30 +- .../include/Compiler/OperationArguments.h | 1 + .../include/Compiler/OperationBlocks.h | 6 +- .../include/Compiler/OperationRegistry.h | 6 +- TensorFrost/include/Compiler/Overloads.h | 7 + TensorFrost/include/Compiler/TFProgram.h | 1 + TensorFrost/include/Compiler/Value.h | 20 +- TensorFrost/src/Compiler/ExecutionContext.cpp | 24 +- TensorFrost/src/Compiler/Operation.cpp | 61 ++- .../src/Compiler/OperationArguments.cpp | 9 + TensorFrost/src/Compiler/OperationBlocks.cpp | 64 +--- .../src/Compiler/OperationRegistry.cpp | 43 ++- TensorFrost/src/Compiler/Overloads.cpp | 33 ++ TensorFrost/src/Compiler/Printer.cpp | 45 ++- TensorFrost/src/Compiler/TFProgram.cpp | 14 +- TensorFrost/src/Compiler/Value.cpp | 99 +++-- 21 files changed, 414 insertions(+), 425 deletions(-) diff --git a/.idea/editor.xml b/.idea/editor.xml index 6692539a..4d514864 100644 --- a/.idea/editor.xml +++ b/.idea/editor.xml @@ -1,302 +1,60 @@ - \ No newline at end of file diff --git a/TensorFrost/CMakeLists.txt b/TensorFrost/CMakeLists.txt index 1fa22052..384bd775 100644 --- a/TensorFrost/CMakeLists.txt +++ b/TensorFrost/CMakeLists.txt @@ -41,6 +41,10 @@ target_link_libraries(TensorFrost PRIVATE 
glad_gl_core_46) glad_add_library(glad_vulkan_12 REPRODUCIBLE LOADER API vulkan=1.2) target_link_libraries(TensorFrost PRIVATE glad_vulkan_12) +if (MSVC) + target_compile_options(TensorFrost PRIVATE /wd4804 /wd4805 /wd4018) +endif() + # ---- ImGui ---- target_include_directories(TensorFrost PRIVATE ${CMAKE_SOURCE_DIR}/external/imgui diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index c0d9c062..8b597c26 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -138,14 +138,26 @@ PYBIND11_MODULE(TensorFrost, m) { std::vector outputs; Value a = 5; Value b = 10; + Value f = 2.5f; + Value g = 3.5f; Value c = a + b * 3; Value mem = memory({a, b, c}, TFTypeFloat32); inputs.push_back(mem); + vmap({a, b}, [&](Value ids0) { + Value something = tofloat(mem * sin(f + g)); + }); vmap({a, b, c}, [&](Value ids0) { - Value imem = toint(mem); + Value imem = toint(mem * sin(f + g)); Value d = c + b + ids0[0] * imem; + Value m0, m1; + if_cond(d > 0, [&]() { + m0 = d * c * imem; + }, [&]() { + m1 = d * c * imem + 1; + }); + Value result = phi({m0, m1}); vmap({c, c}, [&](Value ids1) { - Value m = d * c * imem[{ids1[1], ids1[0], ids0[0]}]; + Value m = result * imem[{ids1[1], ids1[0], ids0[0]}]; outputs.push_back(m); }); }); diff --git a/TensorFrost/include/Compiler/Common.h b/TensorFrost/include/Compiler/Common.h index 0f01e8e3..95a33720 100644 --- a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -100,6 +100,7 @@ class OpBlockIterator; struct ArgumentManager; struct Argument; class Value; +struct Shape; using Attribute = std::variant; using AttributeMap = std::unordered_map; diff --git a/TensorFrost/include/Compiler/ExecutionContext.h b/TensorFrost/include/Compiler/ExecutionContext.h index 15ec02e8..9cfa9e46 100644 --- a/TensorFrost/include/Compiler/ExecutionContext.h +++ b/TensorFrost/include/Compiler/ExecutionContext.h @@ -7,8 +7,7 @@ namespace TensorFrost { struct ExecutionContext { std::unique_ptr 
base_block; - OpBlock::Iterator cursor; - std::stack stack; + std::stack cursor_stack; ExecutionContext(); void BeginCursor(OpBlock::Iterator it); diff --git a/TensorFrost/include/Compiler/Operation.h b/TensorFrost/include/Compiler/Operation.h index 1224958b..52c2b689 100644 --- a/TensorFrost/include/Compiler/Operation.h +++ b/TensorFrost/include/Compiler/Operation.h @@ -26,9 +26,37 @@ struct Op { void AddAttribute(const std::string& name, const Attribute& value); void ChangeAttribute(const std::string& name, const Attribute& value); - void GetAttribute(const std::string& name, Attribute& value) const; + + Attribute GetAttribute(const std::string &name) const; + + bool Compare(const Op& other) const; }; +void ApplyOpTransform(OpBlock& block, const std::function& transform); +void IterateOver(OpBlock &block, const std::function &transform); +std::set CollectDependencies(std::vector ops); + +template +State IterateWithLocalState(OpBlock &block, Fn &&f) { + State current = {}; + for (auto it = block.begin(); it.valid(); it.next()) { + std::vector kids; + for (auto &sb : it->blocks) + kids.push_back(IterateWithLocalState(*sb, f)); + f(it, current, kids); + } + return current; +} + +template +void IterateWithGlobalState(OpBlock &block, State& current, Fn &&f) { + for (auto it = block.begin(); it.valid(); it.next()) { + std::vector kids; + for (auto &sb : it->blocks) + kids.push_back(IterateWithGlobalState(*sb,current, f)); + f(it, current, kids); + } +} } // namespace ir diff --git a/TensorFrost/include/Compiler/OperationArguments.h b/TensorFrost/include/Compiler/OperationArguments.h index ae2805c8..ca5f946f 100644 --- a/TensorFrost/include/Compiler/OperationArguments.h +++ b/TensorFrost/include/Compiler/OperationArguments.h @@ -38,6 +38,7 @@ struct ArgumentManager { Arguments* Get(ArgType type) const; Arguments* operator[](ArgType type) const; + std::vector GetInputs(ArgType type) const; }; } \ No newline at end of file diff --git 
a/TensorFrost/include/Compiler/OperationBlocks.h b/TensorFrost/include/Compiler/OperationBlocks.h index 119af204..cd0d4f34 100644 --- a/TensorFrost/include/Compiler/OperationBlocks.h +++ b/TensorFrost/include/Compiler/OperationBlocks.h @@ -30,6 +30,9 @@ struct OpBlock { Iterator& insert_before(std::unique_ptr op); Iterator& remove(); + Iterator& move_before(Iterator other); + Iterator& move_range_before(Iterator other_start, Iterator other_end); + OpBlock* parent() const { return parent_; } bool valid() const; @@ -41,7 +44,4 @@ struct OpBlock { Iterator end(); }; -void ApplyOpTransform(OpBlock& block, const std::function& transform); -void IterateOver(OpBlock &block, const std::function &transform); -std::set GetDependencies(std::vector ops); } diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h index 5822365c..d522130d 100644 --- a/TensorFrost/include/Compiler/OperationRegistry.h +++ b/TensorFrost/include/Compiler/OperationRegistry.h @@ -18,11 +18,13 @@ enum class OpClass { Constant, TernaryOperator, Memory, + Phi, None, }; enum class OpProp { - ShapeArgs, + Variadic, + HasShape, Load, Store, MemoryOp, @@ -73,7 +75,7 @@ struct OpSpec { OpClass op_class = OpClass::None; std::set props; int blocks = 0; - FoldFn constant_fold = nullptr; + FoldFn const_fold = nullptr; TFDataFormat GetOutputType(const std::vector& args) const; }; diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h index 1c7d640f..b1df8e38 100644 --- a/TensorFrost/include/Compiler/Overloads.h +++ b/TensorFrost/include/Compiler/Overloads.h @@ -14,9 +14,16 @@ Value unpack_tuple(Value x, int index = 0); Value vmap(std::vector shape, std::function body); Value memory(std::vector shape, TFDataFormat type); Value load_at_index(Value mem, std::vector indices); +void if_cond(Value cond, std::function body_true, std::function body_false = nullptr); +Value loop(Value start, Value end, Value step, 
std::function body); +Value phi(std::vector inputs); inline Value toint(Value x) { return func_op("toint", {x}); } inline Value tofloat(Value x) { return func_op("tofloat", {x}); } inline Value touint(Value x) { return func_op("touint", {x}); } inline Value tobool(Value x) { return func_op("tobool", {x}); } + +inline Value sin(Value x) { return func_op("sin", {x}); } +inline Value cos(Value x) { return func_op("cos", {x}); } +inline Value tan(Value x) { return func_op("tan", {x}); } } diff --git a/TensorFrost/include/Compiler/TFProgram.h b/TensorFrost/include/Compiler/TFProgram.h index c55ffb75..fed163df 100644 --- a/TensorFrost/include/Compiler/TFProgram.h +++ b/TensorFrost/include/Compiler/TFProgram.h @@ -15,6 +15,7 @@ class TFProgram { void Compile(); void ConstantFold(); void RemoveUnused(); + void CombineVmapDepthwise(); std::string DebugPrint() const; }; diff --git a/TensorFrost/include/Compiler/Value.h b/TensorFrost/include/Compiler/Value.h index 4b35ee01..dba399fb 100644 --- a/TensorFrost/include/Compiler/Value.h +++ b/TensorFrost/include/Compiler/Value.h @@ -8,7 +8,9 @@ class Value { public: Op* op = nullptr; + Value() = default; Value(Op* operation); + Value(const Op* operation); Value(float value); Value(int value); Value(uint value); @@ -41,8 +43,24 @@ class Value { Value operator-() const; Value operator+() const; Value operator~() const; + + bool Compare(const Value& other) const; }; std::vector values_to_ops(const std::vector& values); std::vector ops_to_values(const std::vector& ops); -} \ No newline at end of file + +struct Shape { + std::vector dimensions; + Shape(std::vector dims) : dimensions(std::move(dims)) {} + Shape(std::initializer_list dims) : dimensions(dims) {} + Shape() = default; + Shape(const Shape& other) : dimensions(other.dimensions) {} + + void AddDimension(const Value& dim); + void AddDimensions(const std::vector& dims); + bool Broadcastable(const Shape& other) const; +}; + +Shape ComputeShape(Value x); +} diff --git 
a/TensorFrost/src/Compiler/ExecutionContext.cpp b/TensorFrost/src/Compiler/ExecutionContext.cpp index 15540cf1..6a5b2c4d 100644 --- a/TensorFrost/src/Compiler/ExecutionContext.cpp +++ b/TensorFrost/src/Compiler/ExecutionContext.cpp @@ -3,31 +3,31 @@ #include "Compiler/OperationBlocks.h" namespace TensorFrost { -ExecutionContext::ExecutionContext(): base_block(std::make_unique()), cursor(base_block->begin()) {} +ExecutionContext::ExecutionContext(): base_block(std::make_unique()) { + cursor_stack.push(base_block->begin()); +} void ExecutionContext::BeginCursor(OpBlock::Iterator it) { - stack.push(&cursor); - cursor = it; + cursor_stack.push(it); } void ExecutionContext::EndCursor() { - if (stack.empty()) { + if (cursor_stack.empty()) { throw std::runtime_error("This is the last cursor, cannot end it"); } - cursor = *stack.top(); - stack.pop(); + cursor_stack.pop(); } Op& ExecutionContext::Add(std::unique_ptr op) { - cursor.insert_before(std::move(op)); - Op* new_op = *cursor; - cursor.next(); // Move cursor to the new op + cursor_stack.top().insert_before(std::move(op)); + Op* new_op = *cursor_stack.top(); + cursor_stack.top().next(); // Move the cursor to the new op return *new_op; } Op& ExecutionContext::AddBeforeCursor(std::unique_ptr op) { - cursor.insert_before(std::move(op)); - return **cursor; + cursor_stack.top().insert_before(std::move(op)); + return **cursor_stack.top(); } ExecutionContext* current_context = nullptr; @@ -57,7 +57,7 @@ OpBlock* GetCurrentBlock() { if (!current_context) { throw std::runtime_error("No execution context available"); } - return current_context->cursor.parent(); + return current_context->cursor_stack.top().parent(); } void BeginCursor(OpBlock::Iterator it) { diff --git a/TensorFrost/src/Compiler/Operation.cpp b/TensorFrost/src/Compiler/Operation.cpp index efe95b59..d1425bfb 100644 --- a/TensorFrost/src/Compiler/Operation.cpp +++ b/TensorFrost/src/Compiler/Operation.cpp @@ -32,11 +32,68 @@ void Op::ChangeAttribute(const 
std::string &name, const Attribute &value) { attributes[name] = value; } -void Op::GetAttribute(const std::string &name, Attribute &value) const { +Attribute Op::GetAttribute(const std::string &name) const { auto it = attributes.find(name); if (it == attributes.end()) { throw std::runtime_error("Attribute '" + name + "' not found in operation '" + opcode + "'"); } - value = it->second; + return it->second; +} + +bool Op::Compare(const Op &other) const { + bool both_const = (opcode == "const" && other.opcode == "const"); + if (both_const) { + // Compare constant values directly + Attribute this_value = GetAttribute("value"); + Attribute other_value = other.GetAttribute("value"); + return (this_value == other_value); + } + return false; // TODO: Implement more complex comparison logic for non-constant operations +} + +void ApplyOpTransform(OpBlock &block, const std::function &transform) { + for (auto& op : block.ops) { + for (auto& sub_block : op->blocks) { + ApplyOpTransform(*sub_block, transform); + } + transform(*op); + } +} + +void IterateOver(OpBlock &block, const std::function &transform) { + for (OpBlock::Iterator it = block.begin(); it.valid(); it.next()) { + for (auto& sub_block : it->blocks) { + IterateOver(*sub_block, transform); + } + transform(it); + } +} + +void ReverseIterateOver(OpBlock &block, const std::function &transform) { + for (OpBlock::Iterator it = block.end(); it.valid(); it.prev()) { + for (auto& sub_block : it->blocks) { + ReverseIterateOver(*sub_block, transform); + } + transform(it); + } +} + +std::set CollectDependencies(std::vector ops) { + std::set dependencies; + std::function collect_dependencies = [&](Op* op) { + if (op == nullptr || dependencies.contains(op)) return; // Already processed + dependencies.insert(op); + for (auto& input : op->args->Get(ArgType::Input)->inputs) { + collect_dependencies(input->from); + } + for (auto& input : op->args->Get(ArgType::Index)->inputs) { + collect_dependencies(input->from); + } + 
collect_dependencies(op->parent_block->parent_op); // Parent depends on this operation + }; + for (Op* op : ops) { + collect_dependencies(op); + } + return dependencies; } } diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp index e8b60147..7972a3f1 100644 --- a/TensorFrost/src/Compiler/OperationArguments.cpp +++ b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -99,4 +99,13 @@ Arguments* ArgumentManager::Get(ArgType type) const { Arguments * ArgumentManager::operator[](ArgType type) const { return Get(type); } + +std::vector ArgumentManager::GetInputs(ArgType type) const { + auto *args = Get(type); + std::vector inputs; + for (const auto& arg : args->inputs) { + inputs.push_back(arg->from); + } + return inputs; +} } diff --git a/TensorFrost/src/Compiler/OperationBlocks.cpp b/TensorFrost/src/Compiler/OperationBlocks.cpp index 7abd4b07..ebc7cfc3 100644 --- a/TensorFrost/src/Compiler/OperationBlocks.cpp +++ b/TensorFrost/src/Compiler/OperationBlocks.cpp @@ -42,6 +42,25 @@ OpBlock::Iterator& OpBlock::Iterator::remove() { return *this; } +OpBlock::Iterator& OpBlock::Iterator::move_before(Iterator other) +{ + auto pos = cur_; + list_->splice(pos, *other.list_, other.cur_); + cur_ = std::prev(pos); + cur_->get()->parent_block = parent_; + return *this; +} + +OpBlock::Iterator& OpBlock::Iterator::move_range_before(Iterator other_start, Iterator other_end) +{ + auto pos = cur_; + for(auto it = other_start.cur_; it != other_end.cur_; ++it) + it->get()->parent_block = parent_; + list_->splice(pos, *other_start.list_, other_start.cur_, other_end.cur_); + cur_ = std::prev(pos); + return *this; +} + bool OpBlock::Iterator::valid() const { return cur_ != list_->end(); } bool OpBlock::Iterator::operator==(const Iterator &o) const { return cur_ == o.cur_; } bool OpBlock::Iterator::operator!=(const Iterator &o) const { return cur_ != o.cur_; } @@ -49,49 +68,4 @@ bool OpBlock::Iterator::operator!=(const Iterator &o) 
const { return cur_ != o.c OpBlock::Iterator OpBlock::begin() { return Iterator(this, ops.begin()); } OpBlock::Iterator OpBlock::end() { return Iterator(this, ops.end()); } -void ApplyOpTransform(OpBlock &block, const std::function &transform) { - for (auto& op : block.ops) { - for (auto& sub_block : op->blocks) { - ApplyOpTransform(*sub_block, transform); - } - transform(*op); - } -} - -void IterateOver(OpBlock &block, const std::function &transform) { - for (OpBlock::Iterator it = block.begin(); it.valid(); it.next()) { - for (auto& sub_block : it->blocks) { - IterateOver(*sub_block, transform); - } - transform(it); - } -} - -void ReverseIterateOver(OpBlock &block, const std::function &transform) { - for (OpBlock::Iterator it = block.end(); it.valid(); it.prev()) { - for (auto& sub_block : it->blocks) { - ReverseIterateOver(*sub_block, transform); - } - transform(it); - } -} - -std::set GetDependencies(std::vector ops) { - std::set dependencies; - std::function collect_dependencies = [&](Op* op) { - if (op == nullptr || dependencies.contains(op)) return; // Already processed - dependencies.insert(op); - for (auto& input : op->args->Get(ArgType::Input)->inputs) { - collect_dependencies(input->from); - } - for (auto& input : op->args->Get(ArgType::Index)->inputs) { - collect_dependencies(input->from); - } - collect_dependencies(op->parent_block->parent_op); // Collect dependencies of the parent op - }; - for (Op* op : ops) { - collect_dependencies(op); - } - return dependencies; -} } diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index 37b26ca2..b6906b35 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -4,7 +4,7 @@ using namespace std; namespace TensorFrost { TFDataFormat OpSpec::GetOutputType(const std::vector &args) const { - if (props.contains(OpProp::ShapeArgs) || args.empty()) { + if (props.contains(OpProp::Variadic) || args.empty()) 
{ return overloads.find({})->second; } auto it = overloads.find(args); @@ -88,29 +88,40 @@ make_fold2([](auto a, auto b) { \ make_fold3([](auto a, auto b, auto c) { return op(a, b, c); }) vector default_operations = { - DEF_OP("memory", "f(); u(); i(); b(); tuple()", OpClass::Memory, .props = {OpProp::ShapeArgs}), + DEF_OP("memory", "f(); u(); i(); b(); tuple()", OpClass::Memory, .props = {OpProp::Variadic, OpProp::HasShape}), DEF_OP("load", "f(f); u(u); i(i); b(b)", OpClass::Function, .props = {OpProp::Load, OpProp::MemoryOp}), DEF_OP("store", "f(f); u(u); i(i); b(b)", OpClass::Function, .props = {OpProp::Store, OpProp::MemoryOp}), DEF_OP("const", "f(); u(); i(); b(); tuple()", OpClass::Constant), DEF_OP("copy", "f(f); u(u); i(i); b(b)", OpClass::Copy), - DEF_OP("add", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .constant_fold = BIN_OP_FOLD(+)), - DEF_OP("sub", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .constant_fold = BIN_OP_FOLD(-)), - DEF_OP("mul", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .constant_fold = BIN_OP_FOLD(*)), - DEF_OP("div", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .constant_fold = BIN_OP_FOLD(/)), - DEF_OP("sin", "f(f); u(u); i(i)", OpClass::UnaryOperator, .constant_fold = UN_FUNC_FOLD(std::sinf)), - DEF_OP("cos", "f(f); u(u); i(i)", OpClass::UnaryOperator, .constant_fold = UN_FUNC_FOLD(std::cosf)), - DEF_OP("tan", "f(f); u(u); i(i)", OpClass::UnaryOperator, .constant_fold = UN_FUNC_FOLD(std::tanf)), - - - DEF_OP("tofloat", "f(i); f(u); f(b)", OpClass::TypeCast), - DEF_OP("toint", "i(f); i(u); i(b)", OpClass::TypeCast), - DEF_OP("touint", "u(f); u(i); u(b)", OpClass::TypeCast), - DEF_OP("tobool", "b(f); b(i); b(u)", OpClass::TypeCast), + DEF_OP("add", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .const_fold = BIN_OP_FOLD(+)), + DEF_OP("sub", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .const_fold = BIN_OP_FOLD(-)), + DEF_OP("mul", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .const_fold = BIN_OP_FOLD(*)), + DEF_OP("div", "f(f,f); 
u(u,u); i(i,i)", OpClass::Operator, .const_fold = BIN_OP_FOLD(/)), + DEF_OP("sin", "f(f); u(u); i(i)", OpClass::UnaryOperator, .const_fold = UN_FUNC_FOLD(std::sinf)), + DEF_OP("cos", "f(f); u(u); i(i)", OpClass::UnaryOperator, .const_fold = UN_FUNC_FOLD(std::cosf)), + DEF_OP("tan", "f(f); u(u); i(i)", OpClass::UnaryOperator, .const_fold = UN_FUNC_FOLD(std::tanf)), + + DEF_OP("eq", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::equal_to<>())), + DEF_OP("ne", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::not_equal_to<>())), + DEF_OP("lt", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::less<>())), + DEF_OP("le", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::less_equal<>())), + DEF_OP("gt", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::greater<>())), + DEF_OP("ge", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::greater_equal<>())), + + DEF_OP("tofloat", "f(f); f(i); f(u); f(b)", OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("toint", "i(f); i(i); i(u); i(b)", OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("touint", "u(f); u(i); u(u); u(b)", OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("tobool", "b(f); b(i); b(u); b(b)", OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), DEF_OP("unpack_tuple_int", "i(tuple)", OpClass::Function), - DEF_OP("vmap", "tuple()", OpClass::Parallel, .props = {OpProp::ShapeArgs}, .blocks = 1), + // Operations with blocks + DEF_OP("vmap", "tuple()", OpClass::Parallel, .props = {OpProp::Variadic, OpProp::HasShape}, .blocks = 1), + DEF_OP("if_cond", "void(b)", OpClass::Function, .blocks = 2), + DEF_OP("loop", "i(i,i,i)", OpClass::Function, .blocks = 1), + + DEF_OP("phi", "f(); u(); i(); b()", OpClass::Phi, .props = {OpProp::Variadic}), }; std::unordered_map> CreateOperationRegistry() 
{ diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index 9fd0fac2..b8e6df70 100644 --- a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -71,4 +71,37 @@ Value load_at_index(Value mem, std::vector indices) { } return make_op("load", indices, {mem}); } + +void if_cond(Value cond, std::function body_true, std::function body_false) { + if (cond.op->type != TFTypeBool32) { + throw std::runtime_error("Condition must be a boolean value"); + } + Value if_op = func_op("if_cond", {cond}); + GetContext()->BeginCursor(if_op.op->GetBlock(0).begin()); + body_true(); + GetContext()->EndCursor(); + if (body_false) { + GetContext()->BeginCursor(if_op.op->GetBlock(1).begin()); + body_false(); + GetContext()->EndCursor(); + } +} + +Value loop(Value start, Value end, Value step, std::function body) { + if (start.op->type != TFTypeInt32 || end.op->type != TFTypeInt32 || step.op->type != TFTypeInt32) { + throw std::runtime_error("Loop indices must be of type int32"); + } + Value loop_op = func_op("loop", {start, end, step}); + GetContext()->BeginCursor(loop_op.op->GetBlock().begin()); + body(loop_op); + GetContext()->EndCursor(); + return loop_op; +} + +Value phi(std::vector inputs) { + Value phi_op = func_op("phi", inputs); + phi_op.op->type = inputs.empty() ? 
TFTypeNone : inputs[0].op->type; // Set type based on first input + return phi_op; +} + } diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp index 93ad764d..29045391 100644 --- a/TensorFrost/src/Compiler/Printer.cpp +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -38,6 +38,14 @@ std::string PrintArguments(const auto_vector>& vec, st return PrintArray(StringifyArguments(vec), begin, end); } +std::string PrintShape(const Shape& shape) { + std::vector dims; + for (const auto& dim : shape.dimensions) { + dims.push_back(VariableName(dim.op)); + } + return PrintArray(dims, "[", "]", ", "); +} + std::string PrintArguments(const Arguments* args) { if (!args) return ""; std::vector inputs = StringifyArguments(args->inputs); @@ -58,10 +66,12 @@ std::string PrintAttribute(Attribute attr) { return oss.str(); } -void PrintOp(const Op* op, std::ostringstream &os) { +std::string PrintOp(const Op* op) { + std::ostringstream os; os << ToString(op->type) << " " << op->varname; if (op->opcode == "const") { - os << " = " << op->attributes.at("value"); + return ""; + //os << " = " << op->attributes.at("value"); } else { std::string inputs = PrintArguments(op->args->Get(ArgType::Input)->inputs, "", ""); std::string index = PrintArguments(op->args->Get(ArgType::Index)->inputs, "index={", "}"); @@ -73,8 +83,11 @@ void PrintOp(const Op* op, std::ostringstream &os) { } std::string attributes_str = PrintArray(attributes, "{", "}"); - os << " = " << op->opcode << "(" << PrintArray({inputs, index, attributes_str}) << ")"; + std::string shape_str = "";// PrintShape(ComputeShape(Value(op))); + + os << shape_str << " = " << op->opcode << "(" << PrintArray({inputs, index, attributes_str}) << ")"; } + return os.str(); } std::string AddIndent(const std::string& str, int indent) { @@ -88,21 +101,19 @@ std::string AddIndent(const std::string& str, int indent) { return indented; } - -std::string PrintBlock(OpBlock &block) { - auto oss = std::ostringstream(); - for 
(auto it = block.begin(); it.valid(); it.next()) { - PrintOp(*it, oss); - if(it->blocks.size() > 0) { - std::vector blocks; - for (auto& sub_block : it->blocks) { - blocks.push_back(AddIndent(PrintBlock(*sub_block.get()), 4)); - } - oss << PrintArray(blocks, " { \n", "}", "} { \n"); +std::string PrintBlock(OpBlock &root) { + return IterateWithLocalState(root, [](OpBlock::Iterator &it, std::string& current, const std::vector &kids) { + std::string result = PrintOp(*it); + if(result.empty()) return; + if (!kids.empty()) { + std::vector indented; + indented.reserve(kids.size()); + for (auto &s : kids) indented.push_back(AddIndent(s, 4)); + result += PrintArray(indented, " { \n", "}", "} else { \n"); } - oss << "\n"; - } - return oss.str(); + result += '\n'; + current += result; + }); } void AssignVariableNames(OpBlock &block) { diff --git a/TensorFrost/src/Compiler/TFProgram.cpp b/TensorFrost/src/Compiler/TFProgram.cpp index 4dc917c4..0f3a8e6e 100644 --- a/TensorFrost/src/Compiler/TFProgram.cpp +++ b/TensorFrost/src/Compiler/TFProgram.cpp @@ -24,7 +24,7 @@ void TFProgram::Compile() { void TFProgram::ConstantFold() { ApplyOpTransform(*GetBaseBlock(), [](Op& op) { OpSpec* spec = GetOpSpec(op.opcode); - if(!spec->constant_fold) return; // Skip if no constant folding is defined for this operation + if(!spec->const_fold) return; // Skip if no constant folding is defined for this operation AttributeVector inputs; for (const auto& arg : op.args->Get(ArgType::Input)->inputs) { if(!arg->from->attributes.contains("value")) { @@ -32,7 +32,7 @@ void TFProgram::ConstantFold() { } inputs.push_back(arg->from->attributes.at("value")); } - Attribute result = spec->constant_fold(inputs); + Attribute result = spec->const_fold(inputs); op.attributes["value"] = result; // Set the result as a constant value op.opcode = "const"; // Change the opcode to constant op.args->RemoveAll(); // Clear all arguments @@ -40,7 +40,7 @@ void TFProgram::ConstantFold() { } void TFProgram::RemoveUnused() { 
- std::set used_ops = GetDependencies(values_to_ops(program_outputs)); + std::set used_ops = CollectDependencies(values_to_ops(program_outputs)); IterateOver(*GetBaseBlock(), [&](OpBlock::Iterator& it) { if (!used_ops.contains(*it)) { it.remove(); // Remove unused operations @@ -48,6 +48,14 @@ void TFProgram::RemoveUnused() { }); } +// Converts multilevel vmap operations into a sequence of vmaps with concatenated shape +void TFProgram::CombineVmapDepthwise() { + IterateOver(*GetBaseBlock(), [&](OpBlock::Iterator& it) { + static OpBlock* last_block = nullptr; + OpBlock* current_block = it->parent_block; + }); +} + std::string TFProgram::DebugPrint() const { std::string program_header = "TFProgram(inputs=" + PrintArray(TransformVector(values_to_ops(program_inputs), VariableName), "[", "]") + ") {\n"; std::string inner_code = PrintBlock(*context.base_block); diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/src/Compiler/Value.cpp index 89c04f50..b9e302a7 100644 --- a/TensorFrost/src/Compiler/Value.cpp +++ b/TensorFrost/src/Compiler/Value.cpp @@ -11,36 +11,18 @@ Value::Value(Op* operation) : op(operation) { } } +Value::Value(const Op *operation) { + if (!operation) { + throw std::runtime_error("Value cannot be constructed with a null Op pointer"); + } + op = const_cast(operation); +} + Value::Value(float value) : op(constant(value).op) {} Value::Value(int value) : op(constant(value).op) {} Value::Value(uint value) : op(constant(value).op) {} Value::Value(bool value) : op(constant(value).op) {} -std::vector values_to_ops(const std::vector& values) { - std::vector ops; - ops.reserve(values.size()); - for (const auto& value : values) { - if (value.op) { - ops.push_back(value.op); - } else { - throw std::runtime_error("Value contains a null Op pointer"); - } - } - return ops; -} - -std::vector ops_to_values(const std::vector& ops) { - std::vector values; - values.reserve(ops.size()); - for (const auto& op : ops) { - if (op) { - values.emplace_back(op); - } else 
{ - throw std::runtime_error("Op pointer in vector is null"); - } - } - return values; -} Value Value::operator+(const Value& other) const { return func_op("add", {op, other.op}); @@ -102,6 +84,11 @@ Value Value::operator~() const { return func_op("not", {op}); } +bool Value::Compare(const Value &other) const { + if(op == other.op) return true; + return op->Compare(*other.op); +} + Value Value::operator[](int index) const { return unpack_tuple(*this, index); } @@ -109,5 +96,67 @@ Value Value::operator[](const std::vector& indices) const { return load_at_index(*this, indices); } +std::vector values_to_ops(const std::vector& values) { + std::vector ops; + ops.reserve(values.size()); + for (const auto& value : values) { + if (value.op) { + ops.push_back(value.op); + } else { + throw std::runtime_error("Value contains a null Op pointer"); + } + } + return ops; +} + +std::vector ops_to_values(const std::vector& ops) { + std::vector values; + values.reserve(ops.size()); + for (const auto& op : ops) { + if (op) { + values.emplace_back(op); + } else { + throw std::runtime_error("Op pointer in vector is null"); + } + } + return values; +} + +void Shape::AddDimension(const Value &dim) { + dimensions.push_back(dim); +} + +void Shape::AddDimensions(const std::vector &dims) { + dimensions.insert(dimensions.end(), dims.begin(), dims.end()); +} + +bool Shape::Broadcastable(const Shape &other) const { + size_t min_size = std::min(dimensions.size(), other.dimensions.size()); + for (size_t i = 0; i < min_size; ++i) { + if (!dimensions[i].Compare(other.dimensions[i])) { + return false; + } + } + return true; +} + +Shape ComputeShape(Value x) { + Shape shape; + std::vector parents; + Op* current = x.op; + while(current) { + parents.push_back(current); + current = current->parent_block->parent_op; + } + std::reverse(parents.begin(), parents.end()); + for (const auto& parent : parents) { + OpSpec* spec = GetOpSpec(parent->opcode); + if(spec->props.contains(OpProp::HasShape)) { + 
shape.AddDimensions(ops_to_values(parent->args->GetInputs(ArgType::Input))); + } + } + return shape; +} + } // namespace TensorFrost From 1cb4c2decc91de5bec319fad13ae276da697ae3c Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Mon, 16 Jun 2025 06:40:00 +0200 Subject: [PATCH 16/44] Refactoring operation handling --- TensorFrost/PybindModule.cpp | 12 +- TensorFrost/include/Compiler/Common.h | 30 +-- .../include/Compiler/OperationArguments.h | 27 +- .../include/Compiler/OperationRegistry.h | 25 +- TensorFrost/src/Compiler/Common.cpp | 8 +- TensorFrost/src/Compiler/Operation.cpp | 2 +- .../src/Compiler/OperationArguments.cpp | 110 +++------ .../src/Compiler/OperationRegistry.cpp | 230 +++++++++++------- TensorFrost/src/Compiler/Overloads.cpp | 9 +- TensorFrost/src/Compiler/TFProgram.cpp | 5 +- TensorFrost/src/Compiler/Value.cpp | 7 +- 11 files changed, 241 insertions(+), 224 deletions(-) diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 8b597c26..ff1cadc7 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -141,7 +141,7 @@ PYBIND11_MODULE(TensorFrost, m) { Value f = 2.5f; Value g = 3.5f; Value c = a + b * 3; - Value mem = memory({a, b, c}, TFTypeFloat32); + Value mem = memory({a, b, c}, TFFloat32); inputs.push_back(mem); vmap({a, b}, [&](Value ids0) { Value something = tofloat(mem * sin(f + g)); @@ -151,9 +151,15 @@ PYBIND11_MODULE(TensorFrost, m) { Value d = c + b + ids0[0] * imem; Value m0, m1; if_cond(d > 0, [&]() { - m0 = d * c * imem; + Value t = d * c * imem; + vmap({c}, [&](Value ids1) { + m0 = t * imem[{ids1[0], ids0[0], ids0[1]}]; + }); }, [&]() { - m1 = d * c * imem + 1; + Value t = d * c / imem; + vmap({c}, [&](Value ids1) { + m1 = t / imem[{ids1[0], ids0[0], ids0[1]}]; + }); }); Value result = phi({m0, m1}); vmap({c, c}, [&](Value ids1) { diff --git a/TensorFrost/include/Compiler/Common.h b/TensorFrost/include/Compiler/Common.h index 95a33720..bc8b4ff8 100644 --- 
a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -11,6 +11,7 @@ #include #include #include +#include namespace TensorFrost { extern "C" { @@ -21,6 +22,7 @@ extern "C" { Bool, Tuple, None, + Unknown, }; struct TFDataFormat { @@ -34,12 +36,13 @@ extern "C" { bool operator>(const TFDataFormat& other) const; }; -#define TFTypeNone TFDataFormat{TFType::None, 0} -#define TFTypeTuple TFDataFormat{TFType::Tuple, 0} -#define TFTypeBool32 TFDataFormat{TFType::Bool, 32} -#define TFTypeFloat32 TFDataFormat{TFType::Float, 32} -#define TFTypeInt32 TFDataFormat{TFType::Int, 32} -#define TFTypeUint32 TFDataFormat{TFType::Uint, 32} +#define TFNone TFDataFormat{TFType::None, 0} +#define TFUnknown TFDataFormat{TFType::Unknown, 0} +#define TFTuple TFDataFormat{TFType::Tuple, 0} +#define TFBool TFDataFormat{TFType::Bool, 32} +#define TFFloat32 TFDataFormat{TFType::Float, 32} +#define TFInt32 TFDataFormat{TFType::Int, 32} +#define TFUint32 TFDataFormat{TFType::Uint, 32} } // Utility class to automatically resize and set elements in a vector @@ -65,20 +68,6 @@ struct VecHash { } }; -enum class ArgType { - Input, - Index, - Count, -}; - -inline std::string ToString(ArgType type) { - switch (type) { - case ArgType::Input: return "Input"; - case ArgType::Index: return "Index"; - default: return "Unknown"; - } -} - inline std::string ToString(const TFDataFormat& format) { switch (format.type) { case TFType::Float: return "float" + std::to_string(format.size); @@ -94,7 +83,6 @@ inline std::string ToString(const TFDataFormat& format) { using uint = unsigned int; struct Op; -struct Arguments; struct OpBlock; class OpBlockIterator; struct ArgumentManager; diff --git a/TensorFrost/include/Compiler/OperationArguments.h b/TensorFrost/include/Compiler/OperationArguments.h index ca5f946f..7e841201 100644 --- a/TensorFrost/include/Compiler/OperationArguments.h +++ b/TensorFrost/include/Compiler/OperationArguments.h @@ -4,41 +4,26 @@ namespace TensorFrost { struct 
Argument { - ArgType type; Op* from = nullptr; Op* to = nullptr; int index = 0; }; -struct Arguments { +struct ArgumentManager { Op* parent_op = nullptr; auto_vector> inputs; std::set> used_at; - void AddInput(ArgType type, Op* from, int index = 0); - bool CheckValidity(bool throw_error = false) const; - void RemoveInput(int index); - - std::vector Args() const; - std::vector Inputs() const; -}; - -struct ArgumentManager { - Op* parent_op = nullptr; - std::array, (int)ArgType::Count> type_args; - ArgumentManager(Op* parent); - void AddArgument(Op &from, ArgType type, int index = 0); + void AddArgument(Op &from, int index = 0); void SetAsOutput(Argument *arg); void RemoveOutput(Argument *arg); - void SetArguments(ArgType type, std::vector args); - void Remove(ArgType type, int index); - void RemoveType(ArgType type); + void SetArguments(std::vector args); + void Remove(int index); void RemoveAll(); - Arguments* Get(ArgType type) const; - Arguments* operator[](ArgType type) const; - std::vector GetInputs(ArgType type) const; + std::vector Args() const; + std::vector Inputs() const; }; } \ No newline at end of file diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h index d522130d..bb5d024b 100644 --- a/TensorFrost/include/Compiler/OperationRegistry.h +++ b/TensorFrost/include/Compiler/OperationRegistry.h @@ -3,8 +3,6 @@ namespace TensorFrost { -using OverloadsMap = std::unordered_map, TFDataFormat, VecHash>; - enum class OpClass { Operator, UnaryOperator, @@ -23,7 +21,6 @@ enum class OpClass { }; enum class OpProp { - Variadic, HasShape, Load, Store, @@ -69,9 +66,29 @@ FoldFn make_fold3(F f) { }; } +enum class ArgProp { + IgnoreShape +}; + +struct ArgSpec { + char out; + std::vector in; + std::map> types; + std::map> props; + bool variadic = false; + + ArgSpec(std::string io, + std::map> types = {}, + std::map> props = {}); + + bool IsValid(std::vector inputs, TFDataFormat output) const; + TFDataFormat 
EstimateOutputType(const std::vector& inputs) const; + std::vector> InputProperties(const std::vector& inputs) const; +}; + struct OpSpec { std::string name; - OverloadsMap overloads; + ArgSpec arg_spec; OpClass op_class = OpClass::None; std::set props; int blocks = 0; diff --git a/TensorFrost/src/Compiler/Common.cpp b/TensorFrost/src/Compiler/Common.cpp index f16071ed..87b6dace 100644 --- a/TensorFrost/src/Compiler/Common.cpp +++ b/TensorFrost/src/Compiler/Common.cpp @@ -23,13 +23,13 @@ bool TFDataFormat::operator>(const TFDataFormat &other) const { TFDataFormat GetTypeFromAttribute(const Attribute& attr) { if (std::holds_alternative(attr)) { - return TFTypeInt32; + return TFInt32; } else if (std::holds_alternative(attr)) { - return TFTypeUint32; + return TFUint32; } else if (std::holds_alternative(attr)) { - return TFTypeFloat32; + return TFFloat32; } else if (std::holds_alternative(attr)) { - return TFTypeBool32; + return TFBool; } throw std::runtime_error("Unsupported attribute type for TFDataFormat conversion"); } diff --git a/TensorFrost/src/Compiler/Operation.cpp b/TensorFrost/src/Compiler/Operation.cpp index d1425bfb..d1a8ae35 100644 --- a/TensorFrost/src/Compiler/Operation.cpp +++ b/TensorFrost/src/Compiler/Operation.cpp @@ -3,7 +3,7 @@ namespace TensorFrost { Op::Op(std::string op_name): opcode(std::move(op_name)) { args = std::make_unique(this); - type = TFTypeNone; + type = TFNone; } OpBlock* Op::NewBlock() { diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp index 7972a3f1..05885541 100644 --- a/TensorFrost/src/Compiler/OperationArguments.cpp +++ b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -1,32 +1,47 @@ #include "Compiler/Operation.h" namespace TensorFrost { -void Arguments::AddInput(ArgType type, Op *from, int index) { - inputs.set_element(index, std::make_unique(Argument{type, from, parent_op, index})); - from->args->SetAsOutput(inputs[index].get()); + 
+ArgumentManager::ArgumentManager(Op *parent): parent_op(parent) { +} + +void ArgumentManager::AddArgument(Op &from, int index) { + inputs.set_element(index, std::make_unique(Argument{&from, parent_op, index})); + from.args->SetAsOutput(inputs[index].get()); +} + +void ArgumentManager::SetAsOutput(Argument *arg) { + used_at.insert({arg->index, arg}); } -void Arguments::RemoveInput(int index) { - if (index < 0 || index >= inputs.size()) return; +void ArgumentManager::RemoveOutput(Argument *arg) { + used_at.erase({arg->index, arg}); +} + +void ArgumentManager::SetArguments(std::vector args) { + for (size_t i = 0; i < args.size(); ++i) { + AddArgument(*args[i], (int)i); + } +} + +void ArgumentManager::Remove(int index) { + if (index < 0 || index >= inputs.size()) { + throw std::out_of_range("Index out of range for arguments"); + } Argument *arg = inputs[index].get(); if(!arg || !arg->from) throw std::runtime_error("Invalid argument"); arg->from->args->RemoveOutput(arg); inputs[index].reset(); } -bool Arguments::CheckValidity(bool throw_error) const { - for (const auto& input : inputs) { - if (!input || !input->from) { - if (throw_error) { - throw std::runtime_error("Invalid argument"); - } - return false; - } +void ArgumentManager::RemoveAll() { + for (size_t i = 0; i < inputs.size(); ++i) { + Remove((int)i); } - return true; + inputs.clear(); } -std::vector Arguments::Args() const { +std::vector ArgumentManager::Args() const { std::vector result; for (const auto& arg : inputs) { if (arg) { @@ -36,7 +51,7 @@ std::vector Arguments::Args() const { return result; } -std::vector Arguments::Inputs() const { +std::vector ArgumentManager::Inputs() const { std::vector result; for (const auto& arg : inputs) { if (arg && arg->from) { @@ -45,67 +60,4 @@ std::vector Arguments::Inputs() const { } return result; } - -ArgumentManager::ArgumentManager(Op *parent): parent_op(parent) { - for (int i = 0; i < (int)ArgType::Count; ++i) { - type_args[i] = std::make_unique(); - 
type_args[i]->parent_op = parent; - } -} - -void ArgumentManager::AddArgument(Op &from, ArgType type, int index) { - type_args[(int)type]->AddInput(type, &from, index); -} - -void ArgumentManager::SetAsOutput(Argument *arg) { - type_args[(int)arg->type]->used_at.insert({arg->index, arg}); -} - -void ArgumentManager::RemoveOutput(Argument *arg) { - type_args[(int)arg->type]->used_at.erase({arg->index, arg}); -} - -void ArgumentManager::SetArguments(ArgType type, std::vector args) { - for (size_t i = 0; i < args.size(); ++i) { - AddArgument(*args[i], type, (int)i); - } -} - -void ArgumentManager::Remove(ArgType type, int index) { - if (index < 0 || index >= type_args[(int)type]->inputs.size()) { - throw std::out_of_range("Index out of range for argument type " + ToString(type)); - } - type_args[(int)type]->RemoveInput(index); -} - -void ArgumentManager::RemoveType(ArgType type) { - auto& args = type_args[(int)type]; - for (size_t i = 0; i < args->inputs.size(); ++i) { - args->RemoveInput((int)i); - } - args->inputs.clear(); -} - -void ArgumentManager::RemoveAll() { - for (int i = 0; i < (int)ArgType::Count; ++i) { - RemoveType((ArgType)i); - } -} - -Arguments* ArgumentManager::Get(ArgType type) const { - return type_args[(int)type].get(); -} - -Arguments * ArgumentManager::operator[](ArgType type) const { - return Get(type); -} - -std::vector ArgumentManager::GetInputs(ArgType type) const { - auto *args = Get(type); - std::vector inputs; - for (const auto& arg : args->inputs) { - inputs.push_back(arg->from); - } - return inputs; -} } diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index b6906b35..fc126023 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -3,68 +3,112 @@ using namespace std; namespace TensorFrost { -TFDataFormat OpSpec::GetOutputType(const std::vector &args) const { - if (props.contains(OpProp::Variadic) || args.empty()) { - 
return overloads.find({})->second; +ArgSpec::ArgSpec(std::string io, std::map> types, + std::map> props) { + this->props = std::move(props); + this->types = std::move(types); + if (io.empty()) { + throw std::invalid_argument("Argument specification cannot be empty"); } - auto it = overloads.find(args); - if (it == overloads.end()) { - std::string error_msg = "No overload found for operation: " + name + " with args: ("; - for (const auto& arg : args) { - error_msg += ToString(arg) + ", "; - } - if (!args.empty()) { - error_msg.pop_back(); // Remove last comma - error_msg.pop_back(); // Remove last space - } - error_msg += ")"; - - throw std::runtime_error(error_msg); + // Parse argument specification (e.g., "x(x,y,t)" -> out = 'x', in = {'x', 'y', 't'}, or "z(y,z,...)" -> out = 'z', in = {'y', 'z'}, variadic = true) + out = io[0]; + //get substring between parentheses + size_t start = io.find('('); + size_t end = io.find(')', start); + //split the substring by commas + std::stringstream ss(io.substr(start + 1, end - start - 1)); + std::string token; + while (std::getline(ss, token, ',')) { + if (token == "...") { + variadic = true; + break; // Variadic argument, stop parsing further + } + if (token.empty()) continue; // Skip empty tokens + in.push_back(token[0]); } - return it->second; } -static const std::unordered_map tok = { - {"f", TFTypeFloat32}, - {"i", TFTypeInt32}, - {"u", TFTypeUint32}, - {"tuple", TFTypeTuple}, - {"b", TFTypeBool32}, - {"void", TFTypeNone}, -}; +bool ArgSpec::IsValid(std::vector inputs, TFDataFormat output) const { + if (variadic) { + if (in.empty() || inputs.empty()) return false; + } else { + if (inputs.size() != in.size()) return false; + } + + auto name_of = [&](size_t i) -> const char& { + return variadic ? 
in.front() : in[i]; + }; + + std::unordered_map seen; + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& n = name_of(i); + if (seen.count(n) && !(seen[n] == inputs[i])) + throw std::runtime_error("Conflicting types for arg " + n); + seen[n] = inputs[i]; + + auto a = types.find(n); + if (a != types.end() && !a->second.count(inputs[i])) + return false; + } + + auto ao = types.find(out); + if (ao != types.end() && !ao->second.count(output)) return false; + if (seen.count(out) && !(seen[out] == output)) return false; + + if (variadic) { + for (size_t i = 1; i < inputs.size(); ++i) + if (!(inputs[i] == inputs[0])) return false; + } + return true; +} -static std::string trim(std::string_view s) { - size_t a = 0, b = s.size(); - while (a < b && std::isspace(static_cast(s[a]))) ++a; - while (b > a && std::isspace(static_cast(s[b - 1]))) --b; - return std::string{s.substr(a, b - a)}; +TFDataFormat ArgSpec::EstimateOutputType(const std::vector &inputs) const { + if (variadic && inputs.empty()) return TFUnknown; + + auto name_of = [&](size_t i) -> const char& { + return variadic ? 
in.front() : in[i]; + }; + + std::unordered_map seen; + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& n = name_of(i); + if (seen.count(n) && !(seen[n] == inputs[i])) + throw std::runtime_error("Conflicting types for arg " + n); + seen[n] = inputs[i]; + } + + if (seen.count(out)) return seen[out]; + + auto ao = types.find(out); + if (ao != types.end() && ao->second.size() == 1) + return *ao->second.begin(); + + return TFUnknown; } -OverloadsMap GenerateOverloadMap(const std::string& input) { - OverloadsMap out; - std::stringstream ss(input); - std::string stmt; - while (std::getline(ss, stmt, ';')) { - stmt = trim(stmt); - if (stmt.empty()) continue; - auto l = stmt.find('('), r = stmt.find(')'); - if (l == std::string::npos || r == std::string::npos || r < l) throw std::runtime_error("Overload syntax error: " + stmt); - auto tgt = trim(stmt.substr(0, l)); - auto args = stmt.substr(l + 1, r - l - 1); - std::vector key; - std::stringstream as(args); - std::string tokarg; - while (std::getline(as, tokarg, ',')) { - tokarg = trim(tokarg); - key.push_back(tok.at(tokarg)); - } - out.emplace(std::move(key), tok.at(tgt)); +std::vector> ArgSpec::InputProperties(const std::vector &inputs) const { + if ((!variadic && inputs.size() != in.size()) || + (variadic && (in.empty() || inputs.empty()))) + return {}; + + auto name_of = [&](size_t i) -> const char& { + return variadic ? in.front() : in[i]; + }; + + std::vector> out; + out.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& n = name_of(i); + auto p = props.find(n); + out.push_back(p == props.end() ? std::set() : p->second); } return out; } -#define DEF_OP(op_name, overload_str, operation_class, ...) 
\ - OpSpec{ .name = op_name, .overloads = GenerateOverloadMap(overload_str), .op_class = operation_class, __VA_ARGS__ } +TFDataFormat OpSpec::GetOutputType(const std::vector &args) const { + TFDataFormat ret = arg_spec.EstimateOutputType(args); + return ret; +} #define BIN_OP_FOLD(op) \ make_fold2([](auto a, auto b) { \ @@ -87,41 +131,63 @@ make_fold2([](auto a, auto b) { \ #define TERN_FUNC_FOLD(op) \ make_fold3([](auto a, auto b, auto c) { return op(a, b, c); }) +#define DEF_OP(op_name, overload_str, operation_class, ...) \ + OpSpec{ .name = op_name, .arg_spec = ArgSpec overload_str, .op_class = operation_class, __VA_ARGS__ } + vector default_operations = { - DEF_OP("memory", "f(); u(); i(); b(); tuple()", OpClass::Memory, .props = {OpProp::Variadic, OpProp::HasShape}), - DEF_OP("load", "f(f); u(u); i(i); b(b)", OpClass::Function, .props = {OpProp::Load, OpProp::MemoryOp}), - DEF_OP("store", "f(f); u(u); i(i); b(b)", OpClass::Function, .props = {OpProp::Store, OpProp::MemoryOp}), - - DEF_OP("const", "f(); u(); i(); b(); tuple()", OpClass::Constant), - DEF_OP("copy", "f(f); u(u); i(i); b(b)", OpClass::Copy), - DEF_OP("add", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .const_fold = BIN_OP_FOLD(+)), - DEF_OP("sub", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .const_fold = BIN_OP_FOLD(-)), - DEF_OP("mul", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .const_fold = BIN_OP_FOLD(*)), - DEF_OP("div", "f(f,f); u(u,u); i(i,i)", OpClass::Operator, .const_fold = BIN_OP_FOLD(/)), - DEF_OP("sin", "f(f); u(u); i(i)", OpClass::UnaryOperator, .const_fold = UN_FUNC_FOLD(std::sinf)), - DEF_OP("cos", "f(f); u(u); i(i)", OpClass::UnaryOperator, .const_fold = UN_FUNC_FOLD(std::cosf)), - DEF_OP("tan", "f(f); u(u); i(i)", OpClass::UnaryOperator, .const_fold = UN_FUNC_FOLD(std::tanf)), - - DEF_OP("eq", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::equal_to<>())), - DEF_OP("ne", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = 
BIN_FUNC_FOLD(std::not_equal_to<>())), - DEF_OP("lt", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::less<>())), - DEF_OP("le", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::less_equal<>())), - DEF_OP("gt", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::greater<>())), - DEF_OP("ge", "b(f,f); b(u,u); b(i,i)", OpClass::Operator, .const_fold = BIN_FUNC_FOLD(std::greater_equal<>())), - - DEF_OP("tofloat", "f(f); f(i); f(u); f(b)", OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), - DEF_OP("toint", "i(f); i(i); i(u); i(b)", OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), - DEF_OP("touint", "u(f); u(i); u(u); u(b)", OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), - DEF_OP("tobool", "b(f); b(i); b(u); b(b)", OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), - - DEF_OP("unpack_tuple_int", "i(tuple)", OpClass::Function), + DEF_OP("memory", ("x(y,...)"), OpClass::Memory, .props = {OpProp::HasShape}), + DEF_OP("load", ("x(x,y,...)", {{'y', {TFInt32}}}, {{'x', {ArgProp::IgnoreShape}}}), + OpClass::Function, .props = {OpProp::Load, OpProp::MemoryOp}), + DEF_OP("store", ("x(x,x,y,...)", {{'y', {TFInt32}}}, {{'x', {ArgProp::IgnoreShape}}}), + OpClass::Function, .props = {OpProp::Store, OpProp::MemoryOp}), + + DEF_OP("const", ("x()"), OpClass::Constant), + DEF_OP("copy", ("x(x)"), OpClass::Copy), + DEF_OP("add", ("x(x,x)"), OpClass::Operator, + .const_fold = BIN_OP_FOLD(+)), + DEF_OP("sub", ("x(x,x)"), OpClass::Operator, + .const_fold = BIN_OP_FOLD(-)), + DEF_OP("mul", ("x(x,x)"), OpClass::Operator, + .const_fold = BIN_OP_FOLD(*)), + DEF_OP("div", ("x(x,x)"), OpClass::Operator, + .const_fold = BIN_OP_FOLD(/)), + DEF_OP("sin", ("x(x)"), OpClass::UnaryOperator, + .const_fold = UN_FUNC_FOLD(std::sinf)), + DEF_OP("cos", ("x(x)"), OpClass::UnaryOperator, + .const_fold = UN_FUNC_FOLD(std::cosf)), + DEF_OP("tan", ("x(x)"), OpClass::UnaryOperator, + 
.const_fold = UN_FUNC_FOLD(std::tanf)), + + DEF_OP("eq", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::equal_to<>())), + DEF_OP("ne", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::not_equal_to<>())), + DEF_OP("lt", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::less<>())), + DEF_OP("le", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::less_equal<>())), + DEF_OP("gt", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::greater<>())), + DEF_OP("ge", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::greater_equal<>())), + + DEF_OP("tofloat", ("x(y)", {{'x', {TFFloat32}}}), OpClass::Function, + .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("toint", ("x(y)", {{'x', {TFInt32}}}), OpClass::Function, + .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("touint", ("x(y)", {{'x', {TFUint32}}}), OpClass::Function, + .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("tobool", ("x(y)", {{'x', {TFBool}}}), OpClass::Function, + .const_fold = UN_FUNC_FOLD(static_cast)), + + DEF_OP("unpack", ("x(y)"), OpClass::Function), // Operations with blocks - DEF_OP("vmap", "tuple()", OpClass::Parallel, .props = {OpProp::Variadic, OpProp::HasShape}, .blocks = 1), - DEF_OP("if_cond", "void(b)", OpClass::Function, .blocks = 2), - DEF_OP("loop", "i(i,i,i)", OpClass::Function, .blocks = 1), + DEF_OP("vmap", ("x(y,...)", {{'x', {TFTuple}}, {'y', {TFInt32}}}), OpClass::Parallel, .props = {OpProp::HasShape}, .blocks = 1), + DEF_OP("if_cond", ("x(y)", {{'x', {TFNone}}, {'y', {TFBool}}}), OpClass::Function, .blocks = 2), + DEF_OP("loop", ("x(x,x,x)", {{'x', {TFInt32}}}), OpClass::Function, .blocks = 1), - DEF_OP("phi", "f(); u(); i(); b()", OpClass::Phi, .props = {OpProp::Variadic}), + DEF_OP("phi", ("x(x,...)"), OpClass::Phi), }; std::unordered_map> CreateOperationRegistry() { 
diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index b8e6df70..77a95117 100644 --- a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -6,7 +6,7 @@ using namespace std; namespace TensorFrost { // General function to create an Op instance in the current execution context -Value make_op(std::string op, std::vector ids, std::vector args) { +Value make_op(std::string op, std::vector args, std::vector mem) { OpSpec* spec = GetOpSpec(op); vector arg_types; for (const auto& arg : args) { @@ -22,12 +22,11 @@ Value make_op(std::string op, std::vector ids, std::vector args) { for (int i = 0; i < spec->blocks; ++i) { op_instance->NewBlock(); } + op_instance = &GetContext()->Add(std::unique_ptr(op_instance)); + Shape shape = ComputeShape(Value(op_instance)); - return Value(&GetContext()->Add(std::unique_ptr(op_instance))); -} -Value func_op(const std::string &name, std::vector args) { - return make_op(name, {}, std::move(args)); + return Value(op_instance); } Value constant(Attribute value) { diff --git a/TensorFrost/src/Compiler/TFProgram.cpp b/TensorFrost/src/Compiler/TFProgram.cpp index 0f3a8e6e..d05a0281 100644 --- a/TensorFrost/src/Compiler/TFProgram.cpp +++ b/TensorFrost/src/Compiler/TFProgram.cpp @@ -51,8 +51,9 @@ void TFProgram::RemoveUnused() { // Converts multilevel vmap operations into a sequence of vmaps with concatenated shape void TFProgram::CombineVmapDepthwise() { IterateOver(*GetBaseBlock(), [&](OpBlock::Iterator& it) { - static OpBlock* last_block = nullptr; - OpBlock* current_block = it->parent_block; + static OpBlock* last_vmap_block = nullptr; + static OpBlock* current_vmap_block = nullptr; + static Shape current_shape; }); } diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/src/Compiler/Value.cpp index b9e302a7..b3a123c8 100644 --- a/TensorFrost/src/Compiler/Value.cpp +++ b/TensorFrost/src/Compiler/Value.cpp @@ -131,8 +131,11 @@ void Shape::AddDimensions(const 
std::vector &dims) { } bool Shape::Broadcastable(const Shape &other) const { - size_t min_size = std::min(dimensions.size(), other.dimensions.size()); - for (size_t i = 0; i < min_size; ++i) { + size_t size = other.dimensions.size(); + if (dimensions.size() < size) { + throw std::runtime_error("Other shape has more dimensions than this shape"); + } + for (size_t i = 0; i < size; ++i) { if (!dimensions[i].Compare(other.dimensions[i])) { return false; } From dc4ae1b9813aa5e1d73e386ed113aae0b6257af3 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz Date: Tue, 17 Jun 2025 04:46:20 +0200 Subject: [PATCH 17/44] Rewrite tuples as indexed values --- TensorFrost/PybindModule.cpp | 14 ++-- TensorFrost/include/Compiler/Common.h | 26 +++++- TensorFrost/include/Compiler/Operation.h | 1 + .../include/Compiler/OperationArguments.h | 14 ++-- .../include/Compiler/OperationRegistry.h | 3 + TensorFrost/include/Compiler/Overloads.h | 29 +++---- TensorFrost/include/Compiler/Printer.h | 2 +- TensorFrost/include/Compiler/Value.h | 22 ++--- TensorFrost/src/Compiler/Operation.cpp | 5 +- .../src/Compiler/OperationArguments.cpp | 36 ++++----- .../src/Compiler/OperationRegistry.cpp | 16 ++-- TensorFrost/src/Compiler/Overloads.cpp | 80 ++++++++++--------- TensorFrost/src/Compiler/Printer.cpp | 28 ++----- TensorFrost/src/Compiler/TFProgram.cpp | 4 +- TensorFrost/src/Compiler/Value.cpp | 69 ++++++++-------- 15 files changed, 179 insertions(+), 170 deletions(-) diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index ff1cadc7..73ff75b5 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -134,8 +134,8 @@ PYBIND11_MODULE(TensorFrost, m) { // TEST CODE TFProgram program([]() -> auto { - std::vector inputs; - std::vector outputs; + Values inputs; + Values outputs; Value a = 5; Value b = 10; Value f = 2.5f; @@ -143,26 +143,26 @@ PYBIND11_MODULE(TensorFrost, m) { Value c = a + b * 3; Value mem = memory({a, b, c}, TFFloat32); inputs.push_back(mem); - 
vmap({a, b}, [&](Value ids0) { + vmap({a, b}, [&](Values ids0) { Value something = tofloat(mem * sin(f + g)); }); - vmap({a, b, c}, [&](Value ids0) { + vmap({a, b, c}, [&](Values ids0) { Value imem = toint(mem * sin(f + g)); Value d = c + b + ids0[0] * imem; Value m0, m1; if_cond(d > 0, [&]() { Value t = d * c * imem; - vmap({c}, [&](Value ids1) { + vmap({c}, [&](Values ids1) { m0 = t * imem[{ids1[0], ids0[0], ids0[1]}]; }); }, [&]() { Value t = d * c / imem; - vmap({c}, [&](Value ids1) { + vmap({c}, [&](Values ids1) { m1 = t / imem[{ids1[0], ids0[0], ids0[1]}]; }); }); Value result = phi({m0, m1}); - vmap({c, c}, [&](Value ids1) { + vmap({c, c}, [&](Values ids1) { Value m = result * imem[{ids1[1], ids1[0], ids0[0]}]; outputs.push_back(m); }); diff --git a/TensorFrost/include/Compiler/Common.h b/TensorFrost/include/Compiler/Common.h index bc8b4ff8..7d4179d2 100644 --- a/TensorFrost/include/Compiler/Common.h +++ b/TensorFrost/include/Compiler/Common.h @@ -20,7 +20,6 @@ extern "C" { Uint, Int, Bool, - Tuple, None, Unknown, }; @@ -38,7 +37,6 @@ extern "C" { #define TFNone TFDataFormat{TFType::None, 0} #define TFUnknown TFDataFormat{TFType::Unknown, 0} -#define TFTuple TFDataFormat{TFType::Tuple, 0} #define TFBool TFDataFormat{TFType::Bool, 32} #define TFFloat32 TFDataFormat{TFType::Float, 32} #define TFInt32 TFDataFormat{TFType::Int, 32} @@ -74,8 +72,7 @@ inline std::string ToString(const TFDataFormat& format) { case TFType::Uint: return "uint" + std::to_string(format.size); case TFType::Int: return "int" + std::to_string(format.size); case TFType::Bool: return "bool" + std::to_string(format.size); - case TFType::Tuple: return "tuple"; - case TFType::None: return "none"; + case TFType::None: return "void"; default: return "unknown"; } } @@ -93,6 +90,7 @@ struct Shape; using Attribute = std::variant; using AttributeMap = std::unordered_map; using AttributeVector = std::vector; +using Values = std::vector; TFDataFormat GetTypeFromAttribute(const Attribute& attr); @@ 
-118,6 +116,26 @@ auto TransformVector(const Container& input, Func func) { } return output; } + +template +auto ConcatVectors(const std::vector& a, const std::vector& b) { + std::vector result; + result.reserve(a.size() + b.size()); + result.insert(result.end(), a.begin(), a.end()); + result.insert(result.end(), b.begin(), b.end()); + return result; +} + +template +auto SliceVector(const std::vector& vec, size_t start, size_t end = -1) { + if (end == -1 || end > vec.size()) { + end = vec.size(); + } + if (start >= end || start >= vec.size()) { + return std::vector(); + } + return std::vector(vec.begin() + start, vec.begin() + end); +} } namespace std { diff --git a/TensorFrost/include/Compiler/Operation.h b/TensorFrost/include/Compiler/Operation.h index 52c2b689..a6c61ab5 100644 --- a/TensorFrost/include/Compiler/Operation.h +++ b/TensorFrost/include/Compiler/Operation.h @@ -15,6 +15,7 @@ struct Op { AttributeMap attributes; TFDataFormat type; std::vector> blocks; + int output_count = 1; OpBlock* parent_block = nullptr; size_t index = 0; //might not be up to date diff --git a/TensorFrost/include/Compiler/OperationArguments.h b/TensorFrost/include/Compiler/OperationArguments.h index 7e841201..118f2a36 100644 --- a/TensorFrost/include/Compiler/OperationArguments.h +++ b/TensorFrost/include/Compiler/OperationArguments.h @@ -6,7 +6,10 @@ namespace TensorFrost { struct Argument { Op* from = nullptr; Op* to = nullptr; - int index = 0; + int arg_index = 0; // Index in to's arguments + int from_index = 0; // Index of from's output + + Value From() const; }; struct ArgumentManager { @@ -15,15 +18,14 @@ struct ArgumentManager { std::set> used_at; ArgumentManager(Op* parent); - void AddArgument(Op &from, int index = 0); + void AddArgument(Value from, int arg_index = 0); void SetAsOutput(Argument *arg); void RemoveOutput(Argument *arg); - void SetArguments(std::vector args); + void SetArguments(Values args); void Remove(int index); void RemoveAll(); - std::vector Args() 
const; - std::vector Inputs() const; + Values Inputs() const; }; -} \ No newline at end of file +} diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h index bb5d024b..d21eb599 100644 --- a/TensorFrost/include/Compiler/OperationRegistry.h +++ b/TensorFrost/include/Compiler/OperationRegistry.h @@ -29,6 +29,7 @@ enum class OpProp { }; using FoldFn = std::function; +using CalcTuple = std::function; [[noreturn]] inline void bad_arity(std::size_t expect, std::size_t got) { @@ -93,8 +94,10 @@ struct OpSpec { std::set props; int blocks = 0; FoldFn const_fold = nullptr; + CalcTuple calc_tuple = nullptr; TFDataFormat GetOutputType(const std::vector& args) const; + bool IsValid(const std::vector& inputs, TFDataFormat output) const; }; void RegisterOperation(const OpSpec& spec); diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/include/Compiler/Overloads.h index b1df8e38..e58d2cbe 100644 --- a/TensorFrost/include/Compiler/Overloads.h +++ b/TensorFrost/include/Compiler/Overloads.h @@ -3,27 +3,28 @@ #include "Value.h" namespace TensorFrost { -Value make_op(std::string op, std::vector ids = {}, std::vector args = {}); -Value func_op(const std::string &name, std::vector args = {}); +std::pair create_op(std::string op, const Values& args, TFDataFormat output_type = TFUnknown); +Value value_op(std::string op, Values args = {}, TFDataFormat output_type = TFUnknown); +Values tuple_op(std::string op, Values args = {}, TFDataFormat output_type = TFUnknown); + Value constant(int value); Value constant(uint value); Value constant(float value); Value constant(bool value); -Value unpack_tuple(Value x, int index = 0); -Value vmap(std::vector shape, std::function body); -Value memory(std::vector shape, TFDataFormat type); -Value load_at_index(Value mem, std::vector indices); +void vmap(Values shape, std::function body); +Value memory(Values shape, TFDataFormat type); +Value load_at_index(Value mem, Values indices); 
void if_cond(Value cond, std::function body_true, std::function body_false = nullptr); Value loop(Value start, Value end, Value step, std::function body); -Value phi(std::vector inputs); +Value phi(Values inputs); -inline Value toint(Value x) { return func_op("toint", {x}); } -inline Value tofloat(Value x) { return func_op("tofloat", {x}); } -inline Value touint(Value x) { return func_op("touint", {x}); } -inline Value tobool(Value x) { return func_op("tobool", {x}); } +inline Value toint(Value x) { return value_op("toint", {x}); } +inline Value tofloat(Value x) { return value_op("tofloat", {x}); } +inline Value touint(Value x) { return value_op("touint", {x}); } +inline Value tobool(Value x) { return value_op("tobool", {x}); } -inline Value sin(Value x) { return func_op("sin", {x}); } -inline Value cos(Value x) { return func_op("cos", {x}); } -inline Value tan(Value x) { return func_op("tan", {x}); } +inline Value sin(Value x) { return value_op("sin", {x}); } +inline Value cos(Value x) { return value_op("cos", {x}); } +inline Value tan(Value x) { return value_op("tan", {x}); } } diff --git a/TensorFrost/include/Compiler/Printer.h b/TensorFrost/include/Compiler/Printer.h index 8a1b7a8d..4ba12306 100644 --- a/TensorFrost/include/Compiler/Printer.h +++ b/TensorFrost/include/Compiler/Printer.h @@ -5,7 +5,7 @@ namespace TensorFrost { std::string VariableName(const Op* op); -void PrintOp(const Op& op, std::ostream& os); +std::string PrintOp(const Op* op); std::string PrintBlock(OpBlock& base_block); void AssignVariableNames(OpBlock &block); std::string PrintAttribute(Attribute attr); diff --git a/TensorFrost/include/Compiler/Value.h b/TensorFrost/include/Compiler/Value.h index dba399fb..a6624d89 100644 --- a/TensorFrost/include/Compiler/Value.h +++ b/TensorFrost/include/Compiler/Value.h @@ -3,23 +3,23 @@ namespace TensorFrost { -// Op wrapper class for overloaded mathematics and operations +// Op wrapper class for overloaded mathematics and manipulations class Value { 
public: Op* op = nullptr; + int out_index = 0; // Index of the output value in the operation Value() = default; - Value(Op* operation); - Value(const Op* operation); + Value(Op* operation, int from_index = 0); + Value(const Op* operation, int from_index = 0); Value(float value); Value(int value); Value(uint value); Value(bool value); - Value(const Value& other) : op(other.op) {} + Value(const Value& other); // indexed access - Value operator[](int index) const; - Value operator[](const std::vector& indices) const; + Value operator[](const Values& indices) const; // binary ops take const ref and are const themselves Value operator+(const Value& other) const; @@ -47,18 +47,18 @@ class Value { bool Compare(const Value& other) const; }; -std::vector values_to_ops(const std::vector& values); -std::vector ops_to_values(const std::vector& ops); +std::vector values_to_ops(const Values& values); +Values ops_to_values(const std::vector& ops); struct Shape { - std::vector dimensions; - Shape(std::vector dims) : dimensions(std::move(dims)) {} + Values dimensions; + Shape(Values dims) : dimensions(std::move(dims)) {} Shape(std::initializer_list dims) : dimensions(dims) {} Shape() = default; Shape(const Shape& other) : dimensions(other.dimensions) {} void AddDimension(const Value& dim); - void AddDimensions(const std::vector& dims); + void AddDimensions(const Values& dims); bool Broadcastable(const Shape& other) const; }; diff --git a/TensorFrost/src/Compiler/Operation.cpp b/TensorFrost/src/Compiler/Operation.cpp index d1a8ae35..9785fcd4 100644 --- a/TensorFrost/src/Compiler/Operation.cpp +++ b/TensorFrost/src/Compiler/Operation.cpp @@ -83,10 +83,7 @@ std::set CollectDependencies(std::vector ops) { std::function collect_dependencies = [&](Op* op) { if (op == nullptr || dependencies.contains(op)) return; // Already processed dependencies.insert(op); - for (auto& input : op->args->Get(ArgType::Input)->inputs) { - collect_dependencies(input->from); - } - for (auto& input : 
op->args->Get(ArgType::Index)->inputs) { + for (auto& input : op->args->inputs) { collect_dependencies(input->from); } collect_dependencies(op->parent_block->parent_op); // Parent depends on this operation diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/src/Compiler/OperationArguments.cpp index 05885541..719c7f59 100644 --- a/TensorFrost/src/Compiler/OperationArguments.cpp +++ b/TensorFrost/src/Compiler/OperationArguments.cpp @@ -2,25 +2,29 @@ namespace TensorFrost { +Value Argument::From() const { + return Value(from, from_index); +} + ArgumentManager::ArgumentManager(Op *parent): parent_op(parent) { } -void ArgumentManager::AddArgument(Op &from, int index) { - inputs.set_element(index, std::make_unique(Argument{&from, parent_op, index})); - from.args->SetAsOutput(inputs[index].get()); +void ArgumentManager::AddArgument(Value from, int arg_index) { + inputs.set_element(arg_index, std::make_unique(Argument{from.op, parent_op, arg_index, from.out_index})); + from.op->args->SetAsOutput(inputs[arg_index].get()); } void ArgumentManager::SetAsOutput(Argument *arg) { - used_at.insert({arg->index, arg}); + used_at.insert({arg->arg_index, arg}); } void ArgumentManager::RemoveOutput(Argument *arg) { - used_at.erase({arg->index, arg}); + used_at.erase({arg->arg_index, arg}); } -void ArgumentManager::SetArguments(std::vector args) { +void ArgumentManager::SetArguments(Values args) { for (size_t i = 0; i < args.size(); ++i) { - AddArgument(*args[i], (int)i); + AddArgument(args[i], (int)i); } } @@ -41,22 +45,10 @@ void ArgumentManager::RemoveAll() { inputs.clear(); } -std::vector ArgumentManager::Args() const { - std::vector result; - for (const auto& arg : inputs) { - if (arg) { - result.push_back(arg.get()); - } - } - return result; -} - -std::vector ArgumentManager::Inputs() const { - std::vector result; +Values ArgumentManager::Inputs() const { + Values result; for (const auto& arg : inputs) { - if (arg && arg->from) { - 
result.push_back(arg->from); - } + if (arg) result.push_back(arg->From()); } return result; } diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index fc126023..4b91d002 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -63,8 +63,6 @@ bool ArgSpec::IsValid(std::vector inputs, TFDataFormat output) con } TFDataFormat ArgSpec::EstimateOutputType(const std::vector &inputs) const { - if (variadic && inputs.empty()) return TFUnknown; - auto name_of = [&](size_t i) -> const char& { return variadic ? in.front() : in[i]; }; @@ -110,6 +108,10 @@ TFDataFormat OpSpec::GetOutputType(const std::vector &args) const return ret; } +bool OpSpec::IsValid(const std::vector& inputs, TFDataFormat output) const { + return arg_spec.IsValid(inputs, output); +} + #define BIN_OP_FOLD(op) \ make_fold2([](auto a, auto b) { \ if constexpr (std::is_same_v || std::is_same_v) { \ @@ -180,10 +182,14 @@ vector default_operations = { DEF_OP("tobool", ("x(y)", {{'x', {TFBool}}}), OpClass::Function, .const_fold = UN_FUNC_FOLD(static_cast)), - DEF_OP("unpack", ("x(y)"), OpClass::Function), - // Operations with blocks - DEF_OP("vmap", ("x(y,...)", {{'x', {TFTuple}}, {'y', {TFInt32}}}), OpClass::Parallel, .props = {OpProp::HasShape}, .blocks = 1), + DEF_OP("vmap", ("x(x,...)", {{'x', {TFInt32}}}), OpClass::Parallel, .props = {OpProp::HasShape}, .blocks = 1, + .calc_tuple = [](Op* op, Values args) -> Values { + Values tuple; + op->output_count = args.size(); + for (size_t i = 0; i < args.size(); ++i) tuple.push_back(Value(op, i)); + return tuple; + }), DEF_OP("if_cond", ("x(y)", {{'x', {TFNone}}, {'y', {TFBool}}}), OpClass::Function, .blocks = 2), DEF_OP("loop", ("x(x,x,x)", {{'x', {TFInt32}}}), OpClass::Function, .blocks = 1), diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index 77a95117..352e3be2 100644 --- 
a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -1,22 +1,28 @@ #include "Compiler/Operation.h" #include "Compiler/ExecutionContext.h" #include "Compiler/Value.h" +#include "Compiler/Printer.h" using namespace std; namespace TensorFrost { + // General function to create an Op instance in the current execution context -Value make_op(std::string op, std::vector args, std::vector mem) { +std::pair create_op(std::string op, const Values& args, TFDataFormat output_type) { OpSpec* spec = GetOpSpec(op); vector arg_types; for (const auto& arg : args) { arg_types.push_back(arg.op->type); } - TFDataFormat output_type = spec->GetOutputType(arg_types); + if (output_type == TFUnknown) { + output_type = spec->GetOutputType(arg_types); + } + if (output_type == TFUnknown) { + throw std::runtime_error("Cannot determine output type for operation '" + op + "'"); + } Op* op_instance = new Op(op); op_instance->type = output_type; - op_instance->args->SetArguments(ArgType::Index, values_to_ops(ids)); - op_instance->args->SetArguments(ArgType::Input, values_to_ops(args)); + op_instance->args->SetArguments(args); // Create blocks for (int i = 0; i < spec->blocks; ++i) { @@ -25,14 +31,29 @@ Value make_op(std::string op, std::vector args, std::vector mem) { op_instance = &GetContext()->Add(std::unique_ptr(op_instance)); Shape shape = ComputeShape(Value(op_instance)); + bool valid = spec->IsValid(arg_types, output_type); + if (!valid) { + throw std::runtime_error("Invalid operation '" + op + "' with arguments: " + + PrintArray(TransformVector(values_to_ops(args), PrintOp), "[", "]", ", \n")); + } + return {op_instance, spec}; +} +Value value_op(std::string op, Values args, TFDataFormat output_type) { + auto [op_instance, spec] = create_op(op, args, output_type); + if (spec->calc_tuple) throw std::runtime_error("Make op only creates single output operations, use calc_tuple for multi-output ops"); return Value(op_instance); } +Values tuple_op(std::string 
op, Values args, TFDataFormat output_type) { + auto [op_instance, spec] = create_op(op, args, output_type); + if (!spec->calc_tuple) throw std::runtime_error("Make tuple op only works for operations with multiple outputs"); + return spec->calc_tuple(op_instance, args); +} + Value constant(Attribute value) { - Value const_op = func_op("const"); + Value const_op = value_op("const", {}, GetTypeFromAttribute(value)); const_op.op->attributes["value"] = value; - const_op.op->type = GetTypeFromAttribute(value); return const_op; } @@ -41,41 +62,27 @@ Value constant(uint value) { return constant(Attribute(value)); } Value constant(float value) { return constant(Attribute(value)); } Value constant(bool value) { return constant(Attribute(value)); } -Value unpack_tuple(Value x, int index) { - if (x.op->type != TFTypeTuple) { - throw std::runtime_error("Cannot unpack non-tuple value"); - } - Value elem = func_op("unpack_tuple_int", {x}); - elem.op->attributes["index"] = index; // Default index - return elem; +Value get_output(Value x, int index) { + return Value(x.op, index); } -Value vmap(std::vector shape, std::function body) { - Value par_op = func_op("vmap", shape); - GetContext()->BeginCursor(par_op.op->GetBlock().begin()); - body(par_op); +void vmap(Values shape, std::function body) { + Values indices = tuple_op("vmap", shape); + GetContext()->BeginCursor(indices[0].op->GetBlock().begin()); + body(indices); GetContext()->EndCursor(); - return par_op; } -Value memory(std::vector shape, TFDataFormat type) { - Value mem_op = func_op("memory", std::move(shape)); - mem_op.op->type = type; - return mem_op; +Value memory(Values shape, TFDataFormat type) { + return value_op("memory", std::move(shape), type); } -Value load_at_index(Value mem, std::vector indices) { - if (mem.op->type.type == TFType::None) { - throw std::runtime_error("Cannot load from a None type memory"); - } - return make_op("load", indices, {mem}); +Value load_at_index(Value mem, Values indices) { + return 
value_op("load", ConcatVectors({mem}, indices)); } void if_cond(Value cond, std::function body_true, std::function body_false) { - if (cond.op->type != TFTypeBool32) { - throw std::runtime_error("Condition must be a boolean value"); - } - Value if_op = func_op("if_cond", {cond}); + Value if_op = value_op("if_cond", {cond}); GetContext()->BeginCursor(if_op.op->GetBlock(0).begin()); body_true(); GetContext()->EndCursor(); @@ -87,20 +94,15 @@ void if_cond(Value cond, std::function body_true, std::function } Value loop(Value start, Value end, Value step, std::function body) { - if (start.op->type != TFTypeInt32 || end.op->type != TFTypeInt32 || step.op->type != TFTypeInt32) { - throw std::runtime_error("Loop indices must be of type int32"); - } - Value loop_op = func_op("loop", {start, end, step}); + Value loop_op = value_op("loop", {start, end, step}); GetContext()->BeginCursor(loop_op.op->GetBlock().begin()); body(loop_op); GetContext()->EndCursor(); return loop_op; } -Value phi(std::vector inputs) { - Value phi_op = func_op("phi", inputs); - phi_op.op->type = inputs.empty() ? TFTypeNone : inputs[0].op->type; // Set type based on first input - return phi_op; +Value phi(Values inputs) { + return value_op("phi", inputs); } } diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp index 29045391..bbf99ae4 100644 --- a/TensorFrost/src/Compiler/Printer.cpp +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -15,7 +15,7 @@ std::string VariableName(const Op* op) { std::vector StringifyArguments(const auto_vector>& vec) { return TransformVector(vec, [](const std::unique_ptr& arg) { - return VariableName(arg->from); + return VariableName(arg->from) + (arg->from->output_count > 1 ? 
"[" + std::to_string(arg->from_index) + "]" : ""); }); } @@ -46,20 +46,6 @@ std::string PrintShape(const Shape& shape) { return PrintArray(dims, "[", "]", ", "); } -std::string PrintArguments(const Arguments* args) { - if (!args) return ""; - std::vector inputs = StringifyArguments(args->inputs); - std::vector outputs; - for (const auto& arg : args->used_at) { - if (arg.second->to) { - outputs.push_back(VariableName(arg.second->to)); - } - } - std::string inputs_str = PrintArray(inputs, "inputs={", "}"); - std::string outputs_str = PrintArray(outputs, "outputs={", "}"); - return "[" + inputs_str + ", " + outputs_str + "]"; -} - std::string PrintAttribute(Attribute attr) { std::ostringstream oss; std::visit([&oss](const auto& v) { oss << v; }, attr); @@ -70,13 +56,10 @@ std::string PrintOp(const Op* op) { std::ostringstream os; os << ToString(op->type) << " " << op->varname; if (op->opcode == "const") { - return ""; - //os << " = " << op->attributes.at("value"); + //return ""; + os << " = " << op->attributes.at("value"); } else { - std::string inputs = PrintArguments(op->args->Get(ArgType::Input)->inputs, "", ""); - std::string index = PrintArguments(op->args->Get(ArgType::Index)->inputs, "index={", "}"); - // std::string inputs = "args=" + PrintArguments(op->args->Get(ArgType::Input)); - // std::string index = "index=" + PrintArguments(op->args->Get(ArgType::Index)); + std::string inputs = PrintArguments(op->args->inputs, "", ""); std::vector attributes; for (const auto& [key, value] : op->attributes) { attributes.push_back(key + ": " + PrintAttribute(value)); @@ -85,7 +68,8 @@ std::string PrintOp(const Op* op) { std::string shape_str = "";// PrintShape(ComputeShape(Value(op))); - os << shape_str << " = " << op->opcode << "(" << PrintArray({inputs, index, attributes_str}) << ")"; + os << shape_str << (op->output_count > 1 ? 
"[" + std::to_string(op->output_count) + "]" : ""); + os << " = " << op->opcode << "(" << PrintArray({inputs, attributes_str}) << ")"; } return os.str(); } diff --git a/TensorFrost/src/Compiler/TFProgram.cpp b/TensorFrost/src/Compiler/TFProgram.cpp index d05a0281..988349a3 100644 --- a/TensorFrost/src/Compiler/TFProgram.cpp +++ b/TensorFrost/src/Compiler/TFProgram.cpp @@ -1,7 +1,7 @@ #include "Compiler/TFProgram.h" namespace TensorFrost { -TFProgram::TFProgram(std::function, std::vector>()> program_fn) { +TFProgram::TFProgram(std::function()> program_fn) { StartExecutionContext(&context); auto [ins, outs] = program_fn(); @@ -26,7 +26,7 @@ void TFProgram::ConstantFold() { OpSpec* spec = GetOpSpec(op.opcode); if(!spec->const_fold) return; // Skip if no constant folding is defined for this operation AttributeVector inputs; - for (const auto& arg : op.args->Get(ArgType::Input)->inputs) { + for (const auto& arg : op.args->inputs) { if(!arg->from->attributes.contains("value")) { return; // Skip if some argument does not have a constant value } diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/src/Compiler/Value.cpp index b3a123c8..6d1501d0 100644 --- a/TensorFrost/src/Compiler/Value.cpp +++ b/TensorFrost/src/Compiler/Value.cpp @@ -5,83 +5,89 @@ using namespace std; namespace TensorFrost { -Value::Value(Op* operation) : op(operation) { +Value::Value(Op* operation, int from_index) : op(operation), out_index(from_index) { if (!op) { throw std::runtime_error("Value cannot be constructed with a null Op pointer"); } + if(from_index >= op->output_count) { + throw std::out_of_range("Output index out of range for the operation"); + } } -Value::Value(const Op *operation) { - if (!operation) { +Value::Value(const Op *operation, int from_index) : out_index(from_index) { + op = const_cast(operation); + if (!op) { throw std::runtime_error("Value cannot be constructed with a null Op pointer"); } - op = const_cast(operation); + if(from_index >= op->output_count) { + throw 
std::out_of_range("Output index out of range for the operation"); + } } Value::Value(float value) : op(constant(value).op) {} Value::Value(int value) : op(constant(value).op) {} Value::Value(uint value) : op(constant(value).op) {} Value::Value(bool value) : op(constant(value).op) {} - +Value::Value(const Value &other): op(other.op), out_index(other.out_index) {} Value Value::operator+(const Value& other) const { - return func_op("add", {op, other.op}); + return value_op("add", {op, other.op}); } Value Value::operator-(const Value& other) const { - return func_op("sub", {op, other.op}); + return value_op("sub", {op, other.op}); } Value Value::operator*(const Value& other) const { - return func_op("mul", {op, other.op}); + return value_op("mul", {op, other.op}); } Value Value::operator/(const Value& other) const { - return func_op("div", {op, other.op}); + return value_op("div", {op, other.op}); } Value Value::operator%(const Value& other) const { - return func_op("mod", {op, other.op}); + return value_op("mod", {op, other.op}); } Value Value::operator==(const Value& other) const { - return func_op("eq", {op, other.op}); + return value_op("eq", {op, other.op}); } Value Value::operator!=(const Value& other) const { - return func_op("ne", {op, other.op}); + return value_op("ne", {op, other.op}); } Value Value::operator<(const Value& other) const { - return func_op("lt", {op, other.op}); + return value_op("lt", {op, other.op}); } Value Value::operator<=(const Value& other) const { - return func_op("le", {op, other.op}); + return value_op("le", {op, other.op}); } Value Value::operator>(const Value& other) const { - return func_op("gt", {op, other.op}); + return value_op("gt", {op, other.op}); } Value Value::operator>=(const Value& other) const { - return func_op("ge", {op, other.op}); + return value_op("ge", {op, other.op}); } Value Value::operator<<(const Value& other) const { - return func_op("shl", {op, other.op}); + return value_op("shl", {op, other.op}); } Value 
Value::operator>>(const Value& other) const { - return func_op("shr", {op, other.op}); + return value_op("shr", {op, other.op}); } Value Value::operator&&(const Value& other) const { - return func_op("land", {op, other.op}); + return value_op("land", {op, other.op}); } Value Value::operator||(const Value& other) const { - return func_op("lor", {op, other.op}); + return value_op("lor", {op, other.op}); } Value Value::operator!() const { - return func_op("lnot", {op}); + return value_op("lnot", {op}); } Value Value::operator-() const { - return func_op("neg", {op}); + return value_op("neg", {op}); } Value Value::operator+() const { - return func_op("pos", {op}); + return value_op("pos", {op}); } Value Value::operator~() const { - return func_op("not", {op}); + return value_op("not", {op}); } bool Value::Compare(const Value &other) const { @@ -89,14 +95,11 @@ bool Value::Compare(const Value &other) const { return op->Compare(*other.op); } -Value Value::operator[](int index) const { - return unpack_tuple(*this, index); -} -Value Value::operator[](const std::vector& indices) const { +Value Value::operator[](const Values& indices) const { return load_at_index(*this, indices); } -std::vector values_to_ops(const std::vector& values) { +std::vector values_to_ops(const Values& values) { std::vector ops; ops.reserve(values.size()); for (const auto& value : values) { @@ -109,8 +112,8 @@ std::vector values_to_ops(const std::vector& values) { return ops; } -std::vector ops_to_values(const std::vector& ops) { - std::vector values; +Values ops_to_values(const std::vector& ops) { + Values values; values.reserve(ops.size()); for (const auto& op : ops) { if (op) { @@ -126,7 +129,7 @@ void Shape::AddDimension(const Value &dim) { dimensions.push_back(dim); } -void Shape::AddDimensions(const std::vector &dims) { +void Shape::AddDimensions(const Values &dims) { dimensions.insert(dimensions.end(), dims.begin(), dims.end()); } @@ -155,7 +158,7 @@ Shape ComputeShape(Value x) { for (const 
auto& parent : parents) { OpSpec* spec = GetOpSpec(parent->opcode); if(spec->props.contains(OpProp::HasShape)) { - shape.AddDimensions(ops_to_values(parent->args->GetInputs(ArgType::Input))); + shape.AddDimensions(parent->args->Inputs()); } } return shape; From d42835a04841192a0611fba376e4e78c08c054d8 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sat, 16 Aug 2025 06:12:45 +0200 Subject: [PATCH 18/44] Add basic vulkan backend --- .gitignore | 2 + .gitmodules | 6 +- .idea/TensorFrost.iml | 2 +- .idea/vcs.xml | 8 + .run/TensorFrost.run.xml | 2 +- CMakeLists.txt | 3 +- TensorFrost/CMakeLists.txt | 11 +- TensorFrost/PybindModule.cpp | 110 ++++++++-- TensorFrost/include/Backend/Vulkan.h | 52 +++++ TensorFrost/include/TensorFrost.h | 3 +- TensorFrost/src/Backend/Vulkan.cpp | 205 ++++++++++++++++++ .../src/Compiler/OperationRegistry.cpp | 4 +- TensorFrost/src/Compiler/Overloads.cpp | 21 +- TensorFrost/src/Compiler/Printer.cpp | 2 +- TensorFrost/src/Compiler/Value.cpp | 2 +- external/glad | 1 - 16 files changed, 386 insertions(+), 48 deletions(-) create mode 100644 TensorFrost/include/Backend/Vulkan.h create mode 100644 TensorFrost/src/Backend/Vulkan.cpp delete mode 160000 external/glad diff --git a/.gitignore b/.gitignore index 718829be..52156960 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,5 @@ imgui.ini *.npz /cmake-build-* *.pyc +/.cmake +/CMakeFiles diff --git a/.gitmodules b/.gitmodules index be28e636..fbd707c9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,9 +4,9 @@ [submodule "external/glfw"] path = external/glfw url = https://github.com/glfw/glfw.git -[submodule "external/glad"] - path = external/glad - url = https://github.com/Dav1dde/glad.git [submodule "external/imgui"] path = external/imgui url = https://github.com/ocornut/imgui.git +[submodule "Python/external/shaderc"] + path = Python/external/shaderc + url = https://github.com/google/shaderc diff --git a/.idea/TensorFrost.iml 
b/.idea/TensorFrost.iml index 1837251a..83f477b5 100644 --- a/.idea/TensorFrost.iml +++ b/.idea/TensorFrost.iml @@ -2,7 +2,7 @@ - + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml index a8974b61..a0feea62 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -6,5 +6,13 @@ + + + + + + + + \ No newline at end of file diff --git a/.run/TensorFrost.run.xml b/.run/TensorFrost.run.xml index 97e80504..ba4e1bbd 100644 --- a/.run/TensorFrost.run.xml +++ b/.run/TensorFrost.run.xml @@ -1,5 +1,5 @@ - + diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c927cf8..60c88f13 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,9 +34,10 @@ set(GLFW_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(GLFW_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) set(PYBIND11_FINDPYTHON ON) +find_package(Vulkan REQUIRED COMPONENTS shaderc_combined) + add_subdirectory(external/pybind11) add_subdirectory(external/glfw) -add_subdirectory(external/glad/cmake) add_subdirectory(TensorFrost) add_subdirectory(examples) diff --git a/TensorFrost/CMakeLists.txt b/TensorFrost/CMakeLists.txt index 384bd775..23da69fd 100644 --- a/TensorFrost/CMakeLists.txt +++ b/TensorFrost/CMakeLists.txt @@ -32,14 +32,9 @@ if(APPLE) ) endif() -# ---- libraries ---- -target_link_libraries(TensorFrost PRIVATE glfw) - -glad_add_library(glad_gl_core_46 SHARED API gl:core=4.6) -target_link_libraries(TensorFrost PRIVATE glad_gl_core_46) - -glad_add_library(glad_vulkan_12 REPRODUCIBLE LOADER API vulkan=1.2) -target_link_libraries(TensorFrost PRIVATE glad_vulkan_12) +target_compile_definitions(TensorFrost PRIVATE VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) +target_link_libraries(TensorFrost PRIVATE Vulkan::Vulkan Vulkan::shaderc_combined glfw) +target_include_directories(TensorFrost PRIVATE $ENV{VULKAN_SDK}/Include) if (MSVC) target_compile_options(TensorFrost PRIVATE /wd4804 /wd4805 /wd4018) diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 73ff75b5..fc17b5ce 100644 --- a/TensorFrost/PybindModule.cpp +++ 
b/TensorFrost/PybindModule.cpp @@ -143,34 +143,100 @@ PYBIND11_MODULE(TensorFrost, m) { Value c = a + b * 3; Value mem = memory({a, b, c}, TFFloat32); inputs.push_back(mem); - vmap({a, b}, [&](Values ids0) { - Value something = tofloat(mem * sin(f + g)); - }); vmap({a, b, c}, [&](Values ids0) { - Value imem = toint(mem * sin(f + g)); - Value d = c + b + ids0[0] * imem; - Value m0, m1; - if_cond(d > 0, [&]() { - Value t = d * c * imem; - vmap({c}, [&](Values ids1) { - m0 = t * imem[{ids1[0], ids0[0], ids0[1]}]; - }); - }, [&]() { - Value t = d * c / imem; - vmap({c}, [&](Values ids1) { - m1 = t / imem[{ids1[0], ids0[0], ids0[1]}]; - }); - }); - Value result = phi({m0, m1}); - vmap({c, c}, [&](Values ids1) { - Value m = result * imem[{ids1[1], ids1[0], ids0[0]}]; - outputs.push_back(m); - }); + Value something = tofloat(mem * sin(f + g)); + outputs.push_back(something); }); + // vmap({a, b, c}, [&](Values ids0) { + // Value imem = toint(mem * sin(f + g)); + // Value d = c + b + ids0[0] * imem; + // Value m0, m1; + // if_cond(d > 0, [&]() { + // Value t = d * c * imem; + // vmap({c}, [&](Values ids1) { + // m0 = t * imem[{ids0[1], ids0[1], ids0[1]}]; + // }); + // }, [&]() { + // Value t = d * c / imem; + // vmap({c}, [&](Values ids1) { + // m1 = t / imem[{ids1[0], ids0[0], ids0[1]}]; + // }); + // }); + // Value result; + // vmap({c}, [&](Values ids1) { + // result = phi({m0, m1}); + // }); + // vmap({c, c}, [&](Values ids1) { + // Value m = result * imem[{ids1[1], ids1[0], ids0[0]}]; + // outputs.push_back(m); + // }); + // }); return std::make_pair(inputs, outputs); }); program.Compile(); py::print(program.DebugPrint()); + + + VulkanContext ctx; + + const size_t N = 1024; + // create buffers + Buffer aBuf = createBuffer(ctx, N, sizeof(float), true); + Buffer bBuf = createBuffer(ctx, N, sizeof(float), true); + Buffer outBuf = createBuffer(ctx, N, sizeof(float), false); + + // map and write input data + float* aPtr = static_cast(ctx.device.mapMemory(aBuf.memory, 0, 
aBuf.size)); + float* bPtr = static_cast(ctx.device.mapMemory(bBuf.memory, 0, bBuf.size)); + for (size_t i = 0; i < N; i++) { + aPtr[i] = static_cast(i); + bPtr[i] = static_cast(2 * i); + } + ctx.device.unmapMemory(aBuf.memory); + ctx.device.unmapMemory(bBuf.memory); + + // load SPIR-V compute shader (compiled from add.comp) + // The GLSL code: +// +// #version 450 +// layout(local_size_x = 64) in; +// layout(set=0,binding=0) readonly buffer A { float a[]; }; +// layout(set=0,binding=1) readonly buffer B { float b[]; }; +// layout(set=0,binding=2) buffer C { float c[]; }; +// void main() { uint idx = gl_GlobalInvocationID.x; c[idx] = a[idx] + b[idx]; } + std::string code = R"( +#version 450 + layout(local_size_x = 64) in; + layout(set=0,binding=0) readonly buffer A { float a[]; }; + layout(set=0,binding=1) readonly buffer B { float b[]; }; + layout(set=0,binding=2) buffer C { float c[]; }; + void main() { + uint idx = gl_GlobalInvocationID.x; + c[idx] = 2.0f * a[idx] + b[idx]; + } +)"; + ComputeProgram prog = createComputeProgramFromGLSL(ctx, code,{ &aBuf, &bBuf },{ &outBuf }); + + // run compute + runProgram(ctx, prog, static_cast(N)); + + // read back result + float* outPtr = static_cast(ctx.device.mapMemory(outBuf.memory, 0, outBuf.size)); + bool ok = true; + for (size_t i = 0; i < N; i++) { + float expected = 2.0f*aPtr[i] + bPtr[i]; + if (outPtr[i] != expected) { + ok = false; break; + } + } + ctx.device.unmapMemory(outBuf.memory); + py::print("Compute result is ", ok ? 
"correct" : "incorrect"); + + // cleanup + destroyComputeProgram(ctx, prog); + destroyBuffer(ctx, aBuf); + destroyBuffer(ctx, bBuf); + destroyBuffer(ctx, outBuf); } } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/include/Backend/Vulkan.h b/TensorFrost/include/Backend/Vulkan.h new file mode 100644 index 00000000..1321163c --- /dev/null +++ b/TensorFrost/include/Backend/Vulkan.h @@ -0,0 +1,52 @@ +#pragma once +#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 +#include +#include +#include + +struct Buffer { + vk::Buffer buffer; + vk::DeviceMemory memory; + size_t size; +}; + +struct ComputeProgram { + vk::ShaderModule shaderModule; + vk::DescriptorSetLayout descriptorLayout; + vk::PipelineLayout pipelineLayout; + vk::Pipeline pipeline; + vk::DescriptorPool descriptorPool; + vk::DescriptorSet descriptorSet; +}; + +// Holds instance, physical device, logical device, queue and command pool. +class VulkanContext { +public: + vk::Instance instance; + vk::PhysicalDevice physicalDevice; + vk::Device device; + vk::Queue computeQueue; + uint32_t queueFamilyIndex; + vk::CommandPool commandPool; + + VulkanContext(); + ~VulkanContext(); +}; + +// Creates a host‑visible storage buffer for read‑only or read‑write access. +Buffer createBuffer(VulkanContext& ctx, size_t count, size_t dtypeSize, bool readOnly); + +// Releases the buffer and its memory. +void destroyBuffer(VulkanContext& ctx, Buffer& buf); + +// Compiles a GLSL compute shader and builds a compute pipeline with descriptors. +ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, + const std::string& glsl_source, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers); + +// Destroys the compute program and associated resources. +void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog); + +// Dispatches a compute program with the given number of invocations. 
+void runProgram(VulkanContext& ctx, ComputeProgram& prog, uint32_t numInvocations); \ No newline at end of file diff --git a/TensorFrost/include/TensorFrost.h b/TensorFrost/include/TensorFrost.h index 13528d8d..c1165ff7 100644 --- a/TensorFrost/include/TensorFrost.h +++ b/TensorFrost/include/TensorFrost.h @@ -7,4 +7,5 @@ #include "Compiler/Overloads.h" #include "Compiler/Value.h" #include "Compiler/Printer.h" -#include "Compiler/TFProgram.h" \ No newline at end of file +#include "Compiler/TFProgram.h" +#include "Backend/Vulkan.h" \ No newline at end of file diff --git a/TensorFrost/src/Backend/Vulkan.cpp b/TensorFrost/src/Backend/Vulkan.cpp new file mode 100644 index 00000000..e0426da4 --- /dev/null +++ b/TensorFrost/src/Backend/Vulkan.cpp @@ -0,0 +1,205 @@ +#include "Backend/Vulkan.h" +VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +#include +#include + +// compile GLSL to SPIR-V at runtime +static std::vector compileGLSLToSpirv(const std::string& source) { + shaderc::Compiler compiler; + shaderc::CompileOptions opts; + opts.SetTargetEnvironment(shaderc_target_env_vulkan, + shaderc_env_version_vulkan_1_1); + shaderc::SpvCompilationResult result = + compiler.CompileGlslToSpv(source, shaderc_compute_shader, "shader", opts); + if (result.GetCompilationStatus() != shaderc_compilation_status_success) { + throw std::runtime_error(result.GetErrorMessage()); + } + return {result.cbegin(), result.cend()}; +} + +// VulkanContext constructor sets up instance, selects a compute device and queue, and creates a command pool. 
+VulkanContext::VulkanContext() { + VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr); // required before vk::createInstance + + // 1) instance + vk::ApplicationInfo appInfo("ComputeFramework", 1, nullptr, 0, VK_API_VERSION_1_1); + vk::InstanceCreateInfo instCreate({}, &appInfo); + instance = vk::createInstance(instCreate); + VULKAN_HPP_DEFAULT_DISPATCHER.init(instance); // load instance-level funcs + + // 2) pick physical device + compute queue family + auto devices = instance.enumeratePhysicalDevices(); + if (devices.empty()) throw std::runtime_error("No physical devices"); + for (auto& pd : devices) { + auto q = pd.getQueueFamilyProperties(); + for (uint32_t i = 0; i < q.size(); ++i) { + if ( (q[i].queueFlags & vk::QueueFlagBits::eCompute) != vk::QueueFlags{} ) { + physicalDevice = pd; + queueFamilyIndex = i; + break; + } + } + if (physicalDevice) break; + } + if (!physicalDevice) throw std::runtime_error("No compute queue"); + + // 3) device + queue + float prio = 1.0f; + vk::DeviceQueueCreateInfo qci({}, queueFamilyIndex, 1, &prio); + vk::DeviceCreateInfo devCreate({}, qci); + device = physicalDevice.createDevice(devCreate); + VULKAN_HPP_DEFAULT_DISPATCHER.init(device); // load device-level funcs + + computeQueue = device.getQueue(queueFamilyIndex, 0); + + // 4) command pool + vk::CommandPoolCreateInfo poolInfo({}, queueFamilyIndex); + commandPool = device.createCommandPool(poolInfo); +} + +// VulkanContext destructor cleans up the command pool, device and instance. 
+VulkanContext::~VulkanContext() { + device.destroyCommandPool(commandPool); + device.destroy(); + instance.destroy(); +} + +// create a storage buffer +Buffer createBuffer(VulkanContext& ctx, size_t count, size_t dtypeSize, bool readOnly) { + Buffer buf; + buf.size = count * dtypeSize; + vk::BufferCreateInfo bci({}, buf.size, + vk::BufferUsageFlagBits::eStorageBuffer | + vk::BufferUsageFlagBits::eTransferSrc | + vk::BufferUsageFlagBits::eTransferDst); + buf.buffer = ctx.device.createBuffer(bci); + auto memReq = ctx.device.getBufferMemoryRequirements(buf.buffer); + + auto memProps = ctx.physicalDevice.getMemoryProperties(); + uint32_t memTypeIndex = 0; + for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { + bool allowed = memReq.memoryTypeBits & (1u << i); + auto typeBits = memReq.memoryTypeBits; + auto flags = memProps.memoryTypes[i].propertyFlags; + + bool ok = (typeBits & (1u << i)) != 0; + bool hostVis = (flags & vk::MemoryPropertyFlagBits::eHostVisible) != vk::MemoryPropertyFlags{}; + if (allowed && hostVis) { + memTypeIndex = i; + break; + } + } + vk::MemoryAllocateInfo allocInfo(memReq.size, memTypeIndex); + buf.memory = ctx.device.allocateMemory(allocInfo); + ctx.device.bindBufferMemory(buf.buffer, buf.memory, 0); + return buf; +} + +void destroyBuffer(VulkanContext& ctx, Buffer& buf) { + ctx.device.destroyBuffer(buf.buffer); + ctx.device.freeMemory(buf.memory); + buf.buffer = nullptr; + buf.memory = nullptr; +} + +// internal helper to build a compute program from SPIR-V +static ComputeProgram createComputeProgram(VulkanContext& ctx, + const std::vector& spirv, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers) { + + ComputeProgram prog; + vk::ShaderModuleCreateInfo smci({}, spirv.size() * sizeof(uint32_t), spirv.data()); + prog.shaderModule = ctx.device.createShaderModule(smci); + + std::vector bindings; + uint32_t binding = 0; + for (size_t i = 0; i < readonlyBuffers.size(); i++) { + bindings.emplace_back(binding++, 
vk::DescriptorType::eStorageBuffer, 1, + vk::ShaderStageFlagBits::eCompute); + } + for (size_t i = 0; i < readwriteBuffers.size(); i++) { + bindings.emplace_back(binding++, vk::DescriptorType::eStorageBuffer, 1, + vk::ShaderStageFlagBits::eCompute); + } + vk::DescriptorSetLayoutCreateInfo dsInfo({}, bindings.size(), bindings.data()); + prog.descriptorLayout = ctx.device.createDescriptorSetLayout(dsInfo); + vk::PipelineLayoutCreateInfo plInfo({}, 1, &prog.descriptorLayout); + prog.pipelineLayout = ctx.device.createPipelineLayout(plInfo); + + vk::PipelineShaderStageCreateInfo stageInfo({}, vk::ShaderStageFlagBits::eCompute, + prog.shaderModule, "main"); + vk::ComputePipelineCreateInfo cpInfo({}, stageInfo, prog.pipelineLayout); + prog.pipeline = ctx.device.createComputePipeline({}, cpInfo).value; + + vk::DescriptorPoolSize poolSize(vk::DescriptorType::eStorageBuffer, + readonlyBuffers.size() + readwriteBuffers.size()); + vk::DescriptorPoolCreateInfo poolInfo({}, 1, 1, &poolSize); + prog.descriptorPool = ctx.device.createDescriptorPool(poolInfo); + vk::DescriptorSetAllocateInfo allocInfo(prog.descriptorPool, 1, &prog.descriptorLayout); + prog.descriptorSet = ctx.device.allocateDescriptorSets(allocInfo)[0]; + + std::vector bufferInfos; + bufferInfos.reserve(readonlyBuffers.size() + readwriteBuffers.size()); + for (auto b : readonlyBuffers) { + bufferInfos.push_back(vk::DescriptorBufferInfo(b->buffer, 0, b->size)); + } + for (auto b : readwriteBuffers) { + bufferInfos.push_back(vk::DescriptorBufferInfo(b->buffer, 0, b->size)); + } + std::vector writes; + for (uint32_t i = 0; i < bufferInfos.size(); i++) { + vk::WriteDescriptorSet w(prog.descriptorSet, i, 0, 1, + vk::DescriptorType::eStorageBuffer, nullptr, + &bufferInfos[i]); + writes.push_back(w); + } + ctx.device.updateDescriptorSets(writes, {}); + return prog; +} + +// public wrapper that compiles GLSL and builds the program +ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, + const std::string& 
glsl_source, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers) { + + auto spirv = compileGLSLToSpirv(glsl_source); + return createComputeProgram(ctx, spirv, readonlyBuffers, readwriteBuffers); +} + +void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog) { + ctx.device.destroyDescriptorPool(prog.descriptorPool); + ctx.device.destroyPipeline(prog.pipeline); + ctx.device.destroyPipelineLayout(prog.pipelineLayout); + ctx.device.destroyDescriptorSetLayout(prog.descriptorLayout); + ctx.device.destroyShaderModule(prog.shaderModule); + prog.pipeline = nullptr; + prog.pipelineLayout = nullptr; + prog.descriptorLayout = nullptr; + prog.descriptorPool = nullptr; + prog.shaderModule = nullptr; +} + +// dispatch compute commands +void runProgram(VulkanContext& ctx, ComputeProgram& prog, uint32_t n) { + vk::CommandBufferAllocateInfo ai(ctx.commandPool, vk::CommandBufferLevel::ePrimary, 1); + auto cmd = ctx.device.allocateCommandBuffers(ai)[0]; + + cmd.begin(vk::CommandBufferBeginInfo{}); + cmd.bindPipeline(vk::PipelineBindPoint::eCompute, prog.pipeline); + cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, prog.pipelineLayout, 0, prog.descriptorSet, {}); + uint32_t gs = 64, groups = (n + gs - 1) / gs; + cmd.dispatch(groups, 1, 1); + cmd.end(); + + vk::Fence fence = ctx.device.createFence({}); + ctx.computeQueue.submit(vk::SubmitInfo(0, nullptr, 0, 1, &cmd), fence); + + [[maybe_unused]] auto rWait = ctx.device.waitForFences(fence, VK_TRUE, UINT64_MAX); + // if you compile with VULKAN_HPP_NO_EXCEPTIONS=1, you can check: + // if (rWait != vk::Result::eSuccess) throw std::runtime_error("waitForFences failed"); + + ctx.device.destroyFence(fence); + ctx.device.freeCommandBuffers(ctx.commandPool, cmd); +} \ No newline at end of file diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index 4b91d002..4a3ac15f 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ 
b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -43,7 +43,7 @@ bool ArgSpec::IsValid(std::vector inputs, TFDataFormat output) con for (size_t i = 0; i < inputs.size(); ++i) { const auto& n = name_of(i); if (seen.count(n) && !(seen[n] == inputs[i])) - throw std::runtime_error("Conflicting types for arg " + n); + return false; // Conflicting types for arg seen[n] = inputs[i]; auto a = types.find(n); @@ -71,7 +71,7 @@ TFDataFormat ArgSpec::EstimateOutputType(const std::vector &inputs for (size_t i = 0; i < inputs.size(); ++i) { const auto& n = name_of(i); if (seen.count(n) && !(seen[n] == inputs[i])) - throw std::runtime_error("Conflicting types for arg " + n); + return TFUnknown; // Conflicting types for arg seen[n] = inputs[i]; } diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/src/Compiler/Overloads.cpp index 352e3be2..eef5a224 100644 --- a/TensorFrost/src/Compiler/Overloads.cpp +++ b/TensorFrost/src/Compiler/Overloads.cpp @@ -17,9 +17,9 @@ std::pair create_op(std::string op, const Values& args, TFDataForm if (output_type == TFUnknown) { output_type = spec->GetOutputType(arg_types); } - if (output_type == TFUnknown) { - throw std::runtime_error("Cannot determine output type for operation '" + op + "'"); - } + // if (output_type == TFUnknown) { + // throw std::runtime_error("Cannot determine output type for operation '" + op + "'"); + // } Op* op_instance = new Op(op); op_instance->type = output_type; op_instance->args->SetArguments(args); @@ -29,13 +29,22 @@ std::pair create_op(std::string op, const Values& args, TFDataForm op_instance->NewBlock(); } op_instance = &GetContext()->Add(std::unique_ptr(op_instance)); - Shape shape = ComputeShape(Value(op_instance)); - bool valid = spec->IsValid(arg_types, output_type); if (!valid) { - throw std::runtime_error("Invalid operation '" + op + "' with arguments: " + + throw std::runtime_error("Invalid operation types for '" + op + "' with arguments: " + PrintArray(TransformVector(values_to_ops(args), 
PrintOp), "[", "]", ", \n")); } + + // Check if the shape is compatible + Shape shape = ComputeShape(Value(op_instance)); + std::vector> input_props = spec->arg_spec.InputProperties(arg_types); + for (size_t i = 0; i < args.size(); ++i) { + if (input_props[i].contains(ArgProp::IgnoreShape)) continue; + if (!shape.Broadcastable(ComputeShape(args[i]))) { + throw std::runtime_error("Incompatible shape for argument " + to_string(i) + " in operation '" + op + "'"); + } + } + return {op_instance, spec}; } diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/src/Compiler/Printer.cpp index bbf99ae4..39a55c50 100644 --- a/TensorFrost/src/Compiler/Printer.cpp +++ b/TensorFrost/src/Compiler/Printer.cpp @@ -56,7 +56,7 @@ std::string PrintOp(const Op* op) { std::ostringstream os; os << ToString(op->type) << " " << op->varname; if (op->opcode == "const") { - //return ""; + //return ""; os << " = " << op->attributes.at("value"); } else { std::string inputs = PrintArguments(op->args->inputs, "", ""); diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/src/Compiler/Value.cpp index 6d1501d0..3a02c1f7 100644 --- a/TensorFrost/src/Compiler/Value.cpp +++ b/TensorFrost/src/Compiler/Value.cpp @@ -136,7 +136,7 @@ void Shape::AddDimensions(const Values &dims) { bool Shape::Broadcastable(const Shape &other) const { size_t size = other.dimensions.size(); if (dimensions.size() < size) { - throw std::runtime_error("Other shape has more dimensions than this shape"); + return false; // Cannot broadcast if this shape has fewer dimensions } for (size_t i = 0; i < size; ++i) { if (!dimensions[i].Compare(other.dimensions[i])) { diff --git a/external/glad b/external/glad deleted file mode 160000 index 658f48e7..00000000 --- a/external/glad +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 658f48e72aee3c6582e80b05ac0f8787a64fe6bb From e26f6a42ad8d8e67df57ca3912e8be41f0f69b15 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sun, 24 
Aug 2025 06:37:38 +0200 Subject: [PATCH 19/44] Basic vulkan binds on python --- .idea/vcs.xml | 9 - README.md | 2 +- TensorFrost/CMakeLists.txt | 27 ++- TensorFrost/PybindModule.cpp | 210 +++++++++++------- TensorFrost/include/Backend/Vulkan.h | 3 + .../include/Compiler/OperationRegistry.h | 3 + TensorFrost/include/Compiler/Value.h | 4 +- TensorFrost/src/Backend/Vulkan.cpp | 67 ++++++ .../src/Compiler/OperationRegistry.cpp | 15 +- TensorFrost/src/Compiler/Value.cpp | 6 + examples/debug.py | 46 ++++ 11 files changed, 295 insertions(+), 97 deletions(-) diff --git a/.idea/vcs.xml b/.idea/vcs.xml index a0feea62..a48f14df 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -2,17 +2,8 @@ - - - - - - - - - \ No newline at end of file diff --git a/README.md b/README.md index dd9a14a3..4f0798c9 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ pip install tensorfrost ## From source -You need to have CMake installed to build the library. +You need to have CMake and Vulkan SDK installed to build the library. 
First clone the repository: ```bash diff --git a/TensorFrost/CMakeLists.txt b/TensorFrost/CMakeLists.txt index 23da69fd..f6743c9d 100644 --- a/TensorFrost/CMakeLists.txt +++ b/TensorFrost/CMakeLists.txt @@ -32,8 +32,19 @@ if(APPLE) ) endif() +find_library(SLANG_LIB_RELEASE NAMES slang PATHS $ENV{VULKAN_SDK}/Lib) +find_library(SLANG_LIB_DEBUG NAMES slangd slang PATHS $ENV{VULKAN_SDK}/Lib) + +if (NOT SLANG_LIB_RELEASE) + message(FATAL_ERROR "slang.lib not found in %VULKAN_SDK%/Lib") +endif() + target_compile_definitions(TensorFrost PRIVATE VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) target_link_libraries(TensorFrost PRIVATE Vulkan::Vulkan Vulkan::shaderc_combined glfw) +target_link_libraries(TensorFrost PRIVATE + $<$:${SLANG_LIB_DEBUG}> + $<$>:${SLANG_LIB_RELEASE}> +) target_include_directories(TensorFrost PRIVATE $ENV{VULKAN_SDK}/Include) if (MSVC) @@ -49,7 +60,7 @@ file(GLOB IMGUI_SOURCE_LIST ${CMAKE_SOURCE_DIR}/external/imgui/*.cpp) file(GLOB IMGUI_BACKEND_SOURCE_LIST ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_glfw.cpp - ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_opengl3.cpp) + ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_vulkan.cpp) target_sources(TensorFrost PRIVATE ${IMGUI_SOURCE_LIST} ${IMGUI_BACKEND_SOURCE_LIST}) # ---- RenderDoc headers ---- @@ -74,4 +85,18 @@ set_target_properties(TensorFrost PROPERTIES VS_DEBUGGER_COMMAND_ARGUMENTS "${DEBUG_PYTHON_SCRIPT}" VS_DEBUGGER_WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" LINK_FLAGS_RELWITHDEBINFO "/PROFILE" +) + +set(VKSDK_BIN "$ENV{VULKAN_SDK}/Bin") + +add_custom_command(TARGET TensorFrost POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory "$" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/slang$<$:d>.dll" "$/slang$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/slang-glslang$<$:d>.dll" "$/slang-glslang$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/slang-glsl-module$<$:d>.dll" "$/slang-glsl-module$<$:d>.dll" + COMMAND 
${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/glslang$<$:d>.dll" "$/glslang$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/glslang-default-resource-limits$<$:d>.dll" "$/glslang-default-resource-limits$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/SPIRV-Tools-shared$<$:d>.dll" "$/SPIRV-Tools-shared$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/SPIRV$<$:d>.dll" "$/SPIRV$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/SPVRemapper$<$:d>.dll" "$/SPVRemapper$<$:d>.dll" ) \ No newline at end of file diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index fc17b5ce..a4740e46 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -24,6 +24,37 @@ namespace TensorFrost { // void ScopeDefinitions(py::module& m, py::class_& py_tensor); // void ModuleDefinitions(py::module& m); +struct PyDevice { + vk::Device* dev{}; + explicit PyDevice(vk::Device* d) : dev(d) {} + + py::memoryview mapMemory(const vk::DeviceMemory& mem, uint64_t offset, uint64_t size, bool readonly=false) { + if (size == 0) throw py::value_error("size==0"); + if (size > static_cast(std::numeric_limits::max())) + throw py::value_error("size too large"); + + void* p = nullptr; + { + py::gil_scoped_release nogil; + p = dev->mapMemory(mem, offset, size); + } + if (!p) throw py::value_error("vkMapMemory returned nullptr"); + + PyObject* raw = PyMemoryView_FromMemory( + reinterpret_cast(p), + static_cast(size), + readonly ? 
PyBUF_READ : PyBUF_WRITE); + if (!raw) throw py::error_already_set(); + + return py::reinterpret_steal(raw); + } + + void unmapMemory(const vk::DeviceMemory& mem) { + py::gil_scoped_release nogil; + dev->unmapMemory(mem); + } +}; + PYBIND11_MODULE(TensorFrost, m) { m.doc() = "TensorFrost library"; // auto data_type = py::enum_(m, "TFType"); @@ -147,96 +178,115 @@ PYBIND11_MODULE(TensorFrost, m) { Value something = tofloat(mem * sin(f + g)); outputs.push_back(something); }); - // vmap({a, b, c}, [&](Values ids0) { - // Value imem = toint(mem * sin(f + g)); - // Value d = c + b + ids0[0] * imem; - // Value m0, m1; - // if_cond(d > 0, [&]() { - // Value t = d * c * imem; - // vmap({c}, [&](Values ids1) { - // m0 = t * imem[{ids0[1], ids0[1], ids0[1]}]; - // }); - // }, [&]() { - // Value t = d * c / imem; - // vmap({c}, [&](Values ids1) { - // m1 = t / imem[{ids1[0], ids0[0], ids0[1]}]; - // }); - // }); - // Value result; - // vmap({c}, [&](Values ids1) { - // result = phi({m0, m1}); - // }); - // vmap({c, c}, [&](Values ids1) { - // Value m = result * imem[{ids1[1], ids1[0], ids0[0]}]; - // outputs.push_back(m); - // }); - // }); + vmap({a, b, c}, [&](Values ids0) { + Value imem = toint(mem * sin(f + g)); + Value d = c + b + ids0[0] * imem; + Value m0; + vmap({c}, [&](Values ids1) { + m0 = 0; + }); + if_cond(d > 0, [&]() { + Value t = d * c * imem; + vmap({c}, [&](Values ids1) { + m0.Set(t * imem[{ids0[1], ids0[1], ids0[1]}]); + }); + }, [&]() { + Value t = d * c / imem; + vmap({c}, [&](Values ids1) { + m0.Set(t / imem[{ids1[0], ids0[0], ids0[1]}]); + }); + }); + vmap({c, c}, [&](Values ids1) { + Value m = m0 * imem[{ids1[1], ids1[0], ids0[0]}]; + outputs.push_back(m); + }); + }); return std::make_pair(inputs, outputs); }); program.Compile(); py::print(program.DebugPrint()); + py::class_(m, "DeviceMemory"); + py::class_(m, "Device") + .def("mapMemory", &PyDevice::mapMemory, + py::arg("memory"), py::arg("offset"), py::arg("size"), py::arg("readonly") = false) + 
.def("unmapMemory", &PyDevice::unmapMemory, py::arg("memory")); - VulkanContext ctx; + py::class_(m, "VulkanContext") + .def(py::init<>()) + .def_property_readonly("device", + [](VulkanContext& c){ return PyDevice(&c.device); }, + py::return_value_policy::reference_internal); - const size_t N = 1024; - // create buffers - Buffer aBuf = createBuffer(ctx, N, sizeof(float), true); - Buffer bBuf = createBuffer(ctx, N, sizeof(float), true); - Buffer outBuf = createBuffer(ctx, N, sizeof(float), false); + py::class_(m, "Buffer") + .def_readonly("memory", &Buffer::memory) + .def_property_readonly("size", [](const Buffer& b){ return b.size; }); - // map and write input data - float* aPtr = static_cast(ctx.device.mapMemory(aBuf.memory, 0, aBuf.size)); - float* bPtr = static_cast(ctx.device.mapMemory(bBuf.memory, 0, bBuf.size)); - for (size_t i = 0; i < N; i++) { - aPtr[i] = static_cast(i); - bPtr[i] = static_cast(2 * i); - } - ctx.device.unmapMemory(aBuf.memory); - ctx.device.unmapMemory(bBuf.memory); + py::class_(m, "ComputeProgram"); - // load SPIR-V compute shader (compiled from add.comp) - // The GLSL code: + m.def("createBuffer", &createBuffer, py::arg("ctx"), py::arg("count"), py::arg("dtypeSize"), py::arg("readOnly"), + py::return_value_policy::move); + m.def("destroyBuffer", &destroyBuffer, py::arg("ctx"), py::arg("buf")); + m.def("createComputeProgramFromGLSL", &createComputeProgramFromGLSL, + py::arg("ctx"), py::arg("glsl_source"), py::arg("readonlyBuffers"), py::arg("readwriteBuffers"), + py::return_value_policy::move, py::keep_alive<0,3>(), py::keep_alive<0,4>()); + m.def("createComputeProgramFromSlang", &createComputeProgramFromSlang, + py::arg("ctx"), py::arg("moduleName"), py::arg("source"), py::arg("entry"), + py::arg("readonlyBuffers"), py::arg("readwriteBuffers"), + py::return_value_policy::move, py::keep_alive<0,5>(), py::keep_alive<0,6>()); + m.def("destroyComputeProgram", &destroyComputeProgram, py::arg("ctx"), py::arg("prog")); + m.def("runProgram", 
&runProgram, py::arg("ctx"), py::arg("prog"), py::arg("numInvocations"), + py::call_guard()); +// VulkanContext ctx; // -// #version 450 -// layout(local_size_x = 64) in; -// layout(set=0,binding=0) readonly buffer A { float a[]; }; -// layout(set=0,binding=1) readonly buffer B { float b[]; }; -// layout(set=0,binding=2) buffer C { float c[]; }; -// void main() { uint idx = gl_GlobalInvocationID.x; c[idx] = a[idx] + b[idx]; } - std::string code = R"( -#version 450 - layout(local_size_x = 64) in; - layout(set=0,binding=0) readonly buffer A { float a[]; }; - layout(set=0,binding=1) readonly buffer B { float b[]; }; - layout(set=0,binding=2) buffer C { float c[]; }; - void main() { - uint idx = gl_GlobalInvocationID.x; - c[idx] = 2.0f * a[idx] + b[idx]; - } -)"; - ComputeProgram prog = createComputeProgramFromGLSL(ctx, code,{ &aBuf, &bBuf },{ &outBuf }); - - // run compute - runProgram(ctx, prog, static_cast(N)); - - // read back result - float* outPtr = static_cast(ctx.device.mapMemory(outBuf.memory, 0, outBuf.size)); - bool ok = true; - for (size_t i = 0; i < N; i++) { - float expected = 2.0f*aPtr[i] + bPtr[i]; - if (outPtr[i] != expected) { - ok = false; break; - } - } - ctx.device.unmapMemory(outBuf.memory); - py::print("Compute result is ", ok ? 
"correct" : "incorrect"); - - // cleanup - destroyComputeProgram(ctx, prog); - destroyBuffer(ctx, aBuf); - destroyBuffer(ctx, bBuf); - destroyBuffer(ctx, outBuf); +// const size_t N = 1024; +// // create buffers +// Buffer aBuf = createBuffer(ctx, N, sizeof(float), true); +// Buffer bBuf = createBuffer(ctx, N, sizeof(float), true); +// Buffer outBuf = createBuffer(ctx, N, sizeof(float), false); +// +// // map and write input data +// float* aPtr = static_cast(ctx.device.mapMemory(aBuf.memory, 0, aBuf.size)); +// float* bPtr = static_cast(ctx.device.mapMemory(bBuf.memory, 0, bBuf.size)); +// for (size_t i = 0; i < N; i++) { +// aPtr[i] = static_cast(i); +// bPtr[i] = static_cast(2 * i); +// } +// ctx.device.unmapMemory(aBuf.memory); +// ctx.device.unmapMemory(bBuf.memory); +// +// std::string code = R"( +// [[vk::binding(2,0)]] RWStructuredBuffer C; +// [[vk::binding(0,0)]] StructuredBuffer A; +// [[vk::binding(1,0)]] StructuredBuffer B; +// +// [shader("compute")] [numthreads(64,1,1)] +// void computeMain(uint3 tid: SV_DispatchThreadID) { +// C[tid.x] = 2.0f*A[tid.x] + B[tid.x]; +// } +// )"; +// ComputeProgram prog = createComputeProgramFromSlang(ctx, "vecadd", code, "computeMain", { &aBuf, &bBuf },{ &outBuf }); +// +// // run compute +// runProgram(ctx, prog, N); +// +// // read back result +// float* outPtr = static_cast(ctx.device.mapMemory(outBuf.memory, 0, outBuf.size)); +// bool ok = true; +// for (size_t i = 0; i < N; i++) { +// float expected = 2.0f*aPtr[i] + bPtr[i]; +// if (outPtr[i] != expected) { +// ok = false; break; +// } +// } +// ctx.device.unmapMemory(outBuf.memory); +// py::print("Compute result is ", ok ? 
"correct" : "incorrect"); +// +// // cleanup +// destroyComputeProgram(ctx, prog); +// destroyBuffer(ctx, aBuf); +// destroyBuffer(ctx, bBuf); +// destroyBuffer(ctx, outBuf); } } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/include/Backend/Vulkan.h b/TensorFrost/include/Backend/Vulkan.h index 1321163c..2bff292a 100644 --- a/TensorFrost/include/Backend/Vulkan.h +++ b/TensorFrost/include/Backend/Vulkan.h @@ -45,6 +45,9 @@ ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers); +ComputeProgram createComputeProgramFromSlang(VulkanContext& ctx, const std::string& moduleName, + const std::string& source, const std::string& entry, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers); + // Destroys the compute program and associated resources. void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog); diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/include/Compiler/OperationRegistry.h index d21eb599..26f2289a 100644 --- a/TensorFrost/include/Compiler/OperationRegistry.h +++ b/TensorFrost/include/Compiler/OperationRegistry.h @@ -17,6 +17,7 @@ enum class OpClass { TernaryOperator, Memory, Phi, + Set, None, }; @@ -82,6 +83,8 @@ struct ArgSpec { std::map> types = {}, std::map> props = {}); + bool ValidArgCount(size_t count) const; + bool IsValid(std::vector inputs, TFDataFormat output) const; TFDataFormat EstimateOutputType(const std::vector& inputs) const; std::vector> InputProperties(const std::vector& inputs) const; diff --git a/TensorFrost/include/Compiler/Value.h b/TensorFrost/include/Compiler/Value.h index a6624d89..11516113 100644 --- a/TensorFrost/include/Compiler/Value.h +++ b/TensorFrost/include/Compiler/Value.h @@ -3,7 +3,7 @@ namespace TensorFrost { -// Op wrapper class for overloaded mathematics and manipulations +// Op thin wrapper class for overloaded mathematics and manipulations class 
Value { public: Op* op = nullptr; @@ -45,6 +45,8 @@ class Value { Value operator~() const; bool Compare(const Value& other) const; + + void Set(Value value); }; std::vector values_to_ops(const Values& values); diff --git a/TensorFrost/src/Backend/Vulkan.cpp b/TensorFrost/src/Backend/Vulkan.cpp index e0426da4..78a1d46f 100644 --- a/TensorFrost/src/Backend/Vulkan.cpp +++ b/TensorFrost/src/Backend/Vulkan.cpp @@ -1,6 +1,8 @@ #include "Backend/Vulkan.h" VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE #include +#include +#include #include // compile GLSL to SPIR-V at runtime @@ -17,6 +19,65 @@ static std::vector compileGLSLToSpirv(const std::string& source) { return {result.cbegin(), result.cend()}; } +std::vector compileSlangToSpirv(const char* moduleName, + const char* source, + const char* entry, + const char* profile /* e.g., "spirv_1_5" */) { + Slang::ComPtr global; + createGlobalSession(global.writeRef()); + + slang::TargetDesc tgt{}; + tgt.format = SLANG_SPIRV; + tgt.profile = global->findProfile(profile); + + slang::SessionDesc sd{}; + sd.targets = &tgt; sd.targetCount = 1; + + Slang::ComPtr session; + global->createSession(sd, session.writeRef()); + + Slang::ComPtr diag; + Slang::ComPtr mod; + mod = session->loadModuleFromSourceString(moduleName, moduleName, source, diag.writeRef()); + if (diag && diag->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)diag->getBufferPointer()); + if (!mod) throw std::runtime_error("slang: module load failed"); + + Slang::ComPtr ep; + mod->findEntryPointByName(entry, ep.writeRef()); + if (!ep) throw std::runtime_error("slang: entry not found"); + + slang::IComponentType* parts[] = { mod.get(), ep.get() }; + Slang::ComPtr composed, linked; + + { + Slang::ComPtr d; + SlangResult r = session->createCompositeComponentType(parts, 2, composed.writeRef(), d.writeRef()); + if (d && d->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)d->getBufferPointer()); + if (SLANG_FAILED(r)) throw std::runtime_error("slang: 
compose failed"); + } + { + Slang::ComPtr d; + SlangResult r = composed->link(linked.writeRef(), d.writeRef()); + if (d && d->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)d->getBufferPointer()); + if (SLANG_FAILED(r)) throw std::runtime_error("slang: link failed"); + } + + Slang::ComPtr spirv; + { + Slang::ComPtr d; + SlangResult r = linked->getEntryPointCode(0, 0, spirv.writeRef(), d.writeRef()); + if (d && d->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)d->getBufferPointer()); + if (SLANG_FAILED(r)) throw std::runtime_error("slang: getEntryPointCode failed"); + } + + size_t n = spirv->getBufferSize(); + auto* p = static_cast(spirv->getBufferPointer()); + std::vector out((n + 3) / 4); + std::memcpy(out.data(), p, n); + return out; +} + + // VulkanContext constructor sets up instance, selects a compute device and queue, and creates a command pool. VulkanContext::VulkanContext() { VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr); // required before vk::createInstance @@ -168,6 +229,12 @@ ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, return createComputeProgram(ctx, spirv, readonlyBuffers, readwriteBuffers); } +ComputeProgram createComputeProgramFromSlang(VulkanContext& ctx, const std::string& moduleName, + const std::string& source, const std::string& entry, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers) { + auto spirv = compileSlangToSpirv(moduleName.c_str(), source.c_str(), entry.c_str(), "spirv_1_5"); + return createComputeProgram(ctx, spirv, readonlyBuffers, readwriteBuffers); +} + void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog) { ctx.device.destroyDescriptorPool(prog.descriptorPool); ctx.device.destroyPipeline(prog.pipeline); diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/src/Compiler/OperationRegistry.cpp index 4a3ac15f..e3076618 100644 --- a/TensorFrost/src/Compiler/OperationRegistry.cpp +++ 
b/TensorFrost/src/Compiler/OperationRegistry.cpp @@ -28,12 +28,13 @@ ArgSpec::ArgSpec(std::string io, std::map> types, } } +bool ArgSpec::ValidArgCount(size_t count) const { + return variadic ? count >= in.size() - 1 : count == in.size(); +} + bool ArgSpec::IsValid(std::vector inputs, TFDataFormat output) const { - if (variadic) { - if (in.empty() || inputs.empty()) return false; - } else { - if (inputs.size() != in.size()) return false; - } + bool valid_count = ValidArgCount(inputs.size()); + if (!valid_count) return false; auto name_of = [&](size_t i) -> const char& { return variadic ? in.front() : in[i]; @@ -63,6 +64,9 @@ bool ArgSpec::IsValid(std::vector inputs, TFDataFormat output) con } TFDataFormat ArgSpec::EstimateOutputType(const std::vector &inputs) const { + bool valid_count = ValidArgCount(inputs.size()); + if (!valid_count) return TFUnknown; + auto name_of = [&](size_t i) -> const char& { return variadic ? in.front() : in[i]; }; @@ -145,6 +149,7 @@ vector default_operations = { DEF_OP("const", ("x()"), OpClass::Constant), DEF_OP("copy", ("x(x)"), OpClass::Copy), + DEF_OP("set", ("x(x,x)"), OpClass::Set), DEF_OP("add", ("x(x,x)"), OpClass::Operator, .const_fold = BIN_OP_FOLD(+)), DEF_OP("sub", ("x(x,x)"), OpClass::Operator, diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/src/Compiler/Value.cpp index 3a02c1f7..52c17758 100644 --- a/TensorFrost/src/Compiler/Value.cpp +++ b/TensorFrost/src/Compiler/Value.cpp @@ -95,6 +95,12 @@ bool Value::Compare(const Value &other) const { return op->Compare(*other.op); } +void Value::Set(Value value) { + Value set = value_op("set", {*this,value}); + this->op = set.op; + this->out_index = set.out_index; +} + Value Value::operator[](const Values& indices) const { return load_at_index(*this, indices); } diff --git a/examples/debug.py b/examples/debug.py index 0c24adae..de3cdd48 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -1,2 +1,48 @@ import numpy as np import TensorFrost as tf + +ctx = 
tf.VulkanContext() + +N = 1024 +aBuf = tf.createBuffer(ctx, N, 4, True) +bBuf = tf.createBuffer(ctx, N, 4, True) +outBuf = tf.createBuffer(ctx, N, 4, False) + +# write inputs via mapped views +a_map = ctx.device.mapMemory(aBuf.memory, 0, aBuf.size) +b_map = ctx.device.mapMemory(bBuf.memory, 0, bBuf.size) +np.frombuffer(a_map, dtype=np.float32)[:] = np.arange(N, dtype=np.float32) +np.frombuffer(b_map, dtype=np.float32)[:] = 2 * np.arange(N, dtype=np.float32) +ctx.device.unmapMemory(aBuf.memory) +ctx.device.unmapMemory(bBuf.memory) + +code = r''' +[[vk::binding(2,0)]] RWStructuredBuffer C; +[[vk::binding(0,0)]] StructuredBuffer A; +[[vk::binding(1,0)]] StructuredBuffer B; + +[shader("compute")] [numthreads(64,1,1)] +void computeMain(uint3 tid: SV_DispatchThreadID) { + C[tid.x] = 2.0f*A[tid.x] + B[tid.x]; +} +''' + +prog = tf.createComputeProgramFromSlang( + ctx, "vecadd", code, "computeMain", + [aBuf, bBuf], [outBuf] +) + +tf.runProgram(ctx, prog, N) + +# read back +out_map = ctx.device.mapMemory(outBuf.memory, 0, outBuf.size) +out = np.frombuffer(out_map, dtype=np.float32).copy() +ctx.device.unmapMemory(outBuf.memory) + +ok = np.allclose(out, 4 * np.arange(N, dtype=np.float32)) +print("Compute result is", "correct" if ok else "incorrect") + +tf.destroyComputeProgram(ctx, prog) +tf.destroyBuffer(ctx, aBuf) +tf.destroyBuffer(ctx, bBuf) +tf.destroyBuffer(ctx, outBuf) \ No newline at end of file From f840a502f980054c5d1eaf9ec7de4d01cb92e6b4 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Mon, 25 Aug 2025 04:17:36 +0200 Subject: [PATCH 20/44] Vulkan, GLFW and frontend Improved Vulkan and basic GLFW window with python bindings --- .run/TensorFrost.run.xml | 2 +- TensorFrost/PybindModule.cpp | 145 +++++--- TensorFrost/include/Backend/Vulkan.h | 79 ++++- TensorFrost/include/Backend/Window.h | 79 +++++ TensorFrost/include/TensorFrost.h | 3 +- TensorFrost/src/Backend/Vulkan.cpp | 496 +++++++++++++++++++-------- 
TensorFrost/src/Backend/Window.cpp | 122 +++++++ examples/Slang/mandelbrot.py | 37 ++ examples/Slang/mandelbrot.slang | 50 +++ examples/debug.py | 116 ++++--- 10 files changed, 872 insertions(+), 257 deletions(-) create mode 100644 TensorFrost/include/Backend/Window.h create mode 100644 TensorFrost/src/Backend/Window.cpp create mode 100644 examples/Slang/mandelbrot.py create mode 100644 examples/Slang/mandelbrot.slang diff --git a/.run/TensorFrost.run.xml b/.run/TensorFrost.run.xml index ba4e1bbd..61d7af64 100644 --- a/.run/TensorFrost.run.xml +++ b/.run/TensorFrost.run.xml @@ -1,5 +1,5 @@ - + diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index a4740e46..35f8468c 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -24,36 +24,14 @@ namespace TensorFrost { // void ScopeDefinitions(py::module& m, py::class_& py_tensor); // void ModuleDefinitions(py::module& m); -struct PyDevice { - vk::Device* dev{}; - explicit PyDevice(vk::Device* d) : dev(d) {} - - py::memoryview mapMemory(const vk::DeviceMemory& mem, uint64_t offset, uint64_t size, bool readonly=false) { - if (size == 0) throw py::value_error("size==0"); - if (size > static_cast(std::numeric_limits::max())) - throw py::value_error("size too large"); - - void* p = nullptr; - { - py::gil_scoped_release nogil; - p = dev->mapMemory(mem, offset, size); - } - if (!p) throw py::value_error("vkMapMemory returned nullptr"); - - PyObject* raw = PyMemoryView_FromMemory( - reinterpret_cast(p), - static_cast(size), - readonly ? 
PyBUF_READ : PyBUF_WRITE); - if (!raw) throw py::error_already_set(); - - return py::reinterpret_steal(raw); - } - - void unmapMemory(const vk::DeviceMemory& mem) { - py::gil_scoped_release nogil; - dev->unmapMemory(mem); +static bool is_c_contig(const py::buffer_info& i) { + py::ssize_t stride = i.itemsize; + for (py::ssize_t d = i.ndim - 1; d >= 0; --d) { + if (i.strides[d] != stride) return false; + stride *= i.shape[d]; } -}; + return true; +} PYBIND11_MODULE(TensorFrost, m) { m.doc() = "TensorFrost library"; @@ -206,36 +184,93 @@ PYBIND11_MODULE(TensorFrost, m) { program.Compile(); py::print(program.DebugPrint()); - py::class_(m, "DeviceMemory"); - py::class_(m, "Device") - .def("mapMemory", &PyDevice::mapMemory, - py::arg("memory"), py::arg("offset"), py::arg("size"), py::arg("readonly") = false) - .def("unmapMemory", &PyDevice::unmapMemory, py::arg("memory")); + py::class_(m, "VulkanContext").def(py::init<>()); + + py::class_(m, "Buffer") + .def_property_readonly("size", [](const Buffer& b){ return b.size; }); + + py::class_(m, "ComputeProgram"); + + m.def("createBuffer", &createBuffer, + py::arg("ctx"), py::arg("count"), py::arg("dtypeSize"), py::arg("readOnly"), + py::return_value_policy::move); - py::class_(m, "VulkanContext") - .def(py::init<>()) - .def_property_readonly("device", - [](VulkanContext& c){ return PyDevice(&c.device); }, - py::return_value_policy::reference_internal); + m.def("destroyBuffer", &destroyBuffer, py::arg("ctx"), py::arg("buf")); - py::class_(m, "Buffer") - .def_readonly("memory", &Buffer::memory) - .def_property_readonly("size", [](const Buffer& b){ return b.size; }); + m.def("createComputeProgramFromGLSL", &createComputeProgramFromGLSL, + py::arg("ctx"), py::arg("glsl_source"), py::arg("roCount"), py::arg("rwCount"), + py::return_value_policy::move); - py::class_(m, "ComputeProgram"); + m.def("createComputeProgramFromSlang", &createComputeProgramFromSlang, + py::arg("ctx"), py::arg("moduleName"), py::arg("source"), 
py::arg("entry"), + py::arg("roCount"), py::arg("rwCount"), + py::return_value_policy::move); + + m.def("destroyComputeProgram", &destroyComputeProgram, py::arg("ctx"), py::arg("prog")); + + m.def("runProgram", &runProgram, + py::arg("ctx"), py::arg("prog"), + py::arg("readonlyBuffers"), py::arg("readwriteBuffers"), + py::arg("numInvocations"), + py::call_guard()); + + // --- numpy I/O --- + m.def("setBufferData", + [](VulkanContext& ctx, Buffer& buf, py::array arr, size_t offset) { + auto info = arr.request(); // GIL held + if (!is_c_contig(info)) throw std::runtime_error("array must be C-contiguous"); + size_t nbytes = static_cast(info.size) * static_cast(info.itemsize); + if (offset + nbytes > buf.size) throw std::out_of_range("write out of range"); + py::gil_scoped_release release; + setBufferData(ctx, buf, info.ptr, nbytes, offset); + }, + py::arg("ctx"), py::arg("buf"), py::arg("array"), py::arg("offset") = 0); + + m.def("getBufferData", + [](VulkanContext& ctx, const Buffer& buf, py::dtype dt, size_t count, size_t offset) { + size_t itemsize = dt.attr("itemsize").cast(); // GIL held + size_t nbytes = count * itemsize; + if (offset + nbytes > buf.size) throw std::out_of_range("read out of range"); + + py::array out(dt, py::array::ShapeContainer{ static_cast(count) }); + auto info = out.request(); // contiguous by default + + { py::gil_scoped_release release; + getBufferData(ctx, buf, info.ptr, nbytes, offset); + } + return out; + }, + py::arg("ctx"), py::arg("buf"), py::arg("dtype"), py::arg("count"), py::arg("offset") = 0); + + m.def("getBufferData_into", + [](VulkanContext& ctx, const Buffer& buf, py::array out, size_t offset) { + auto info = out.request(); // GIL held + if (info.readonly) throw std::runtime_error("output array must be writeable"); + if (!is_c_contig(info)) throw std::runtime_error("output array must be C-contiguous"); + size_t nbytes = static_cast(info.size) * static_cast(info.itemsize); + if (offset + nbytes > buf.size) throw 
std::out_of_range("read out of range"); + py::gil_scoped_release release; + getBufferData(ctx, buf, info.ptr, nbytes, offset); + }, + py::arg("ctx"), py::arg("buf"), py::arg("out"), py::arg("offset") = 0); + + py::class_(m, "WindowContext") + .def_property_readonly("size", + [](const WindowContext& c){ return py::make_tuple(c.extent.width, c.extent.height); }) + .def_property_readonly("format", + [](const WindowContext& c){ return static_cast(c.format); }); + + m.def("createWindow", + static_cast(&createWindow), + py::arg("ctx"), py::arg("width"), py::arg("height"), py::arg("title"), + py::return_value_policy::move, + py::keep_alive<0,1>(), // keep ctx alive as long as WindowContext lives + py::call_guard()); - m.def("createBuffer", &createBuffer, py::arg("ctx"), py::arg("count"), py::arg("dtypeSize"), py::arg("readOnly"), - py::return_value_policy::move); - m.def("destroyBuffer", &destroyBuffer, py::arg("ctx"), py::arg("buf")); - m.def("createComputeProgramFromGLSL", &createComputeProgramFromGLSL, - py::arg("ctx"), py::arg("glsl_source"), py::arg("readonlyBuffers"), py::arg("readwriteBuffers"), - py::return_value_policy::move, py::keep_alive<0,3>(), py::keep_alive<0,4>()); - m.def("createComputeProgramFromSlang", &createComputeProgramFromSlang, - py::arg("ctx"), py::arg("moduleName"), py::arg("source"), py::arg("entry"), - py::arg("readonlyBuffers"), py::arg("readwriteBuffers"), - py::return_value_policy::move, py::keep_alive<0,5>(), py::keep_alive<0,6>()); - m.def("destroyComputeProgram", &destroyComputeProgram, py::arg("ctx"), py::arg("prog")); - m.def("runProgram", &runProgram, py::arg("ctx"), py::arg("prog"), py::arg("numInvocations"), + m.def("windowOpen", &windowOpen, py::arg("ctx")); + m.def("drawBuffer", + static_cast(&drawBuffer), + py::arg("ctx"), py::arg("buffer"), py::arg("width"), py::arg("height"), py::arg("offset") = 0, py::call_guard()); // VulkanContext ctx; // diff --git a/TensorFrost/include/Backend/Vulkan.h 
b/TensorFrost/include/Backend/Vulkan.h index 2bff292a..8c222b0d 100644 --- a/TensorFrost/include/Backend/Vulkan.h +++ b/TensorFrost/include/Backend/Vulkan.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include struct Buffer { vk::Buffer buffer; @@ -15,41 +17,80 @@ struct ComputeProgram { vk::DescriptorSetLayout descriptorLayout; vk::PipelineLayout pipelineLayout; vk::Pipeline pipeline; - vk::DescriptorPool descriptorPool; - vk::DescriptorSet descriptorSet; + uint32_t numRO = 0, numRW = 0; +}; + +struct ComputeBindings { + vk::DescriptorSet set{}; +}; + +// --- cache key + hash (as before) +struct DSKey { + vk::DescriptorSetLayout layout{}; + std::vector bufs; + std::vector sizes; + bool operator==(const DSKey& o) const { + return layout==o.layout && bufs==o.bufs && sizes==o.sizes; + } +}; +struct DSKeyHash { + size_t operator()(DSKey const& k) const noexcept { + auto mix = [](size_t& s, uint64_t v){ s ^= std::hash{}(v) + 0x9e3779b97f4a7c15ULL + (s<<6) + (s>>2); }; + size_t seed = 0; + mix(seed, (uint64_t)VkDescriptorSetLayout(k.layout)); + for (auto b : k.bufs) mix(seed, (uint64_t)b); + for (auto sz: k.sizes) mix(seed, (uint64_t)sz); + return seed; + } +}; + +// --- cached entry +struct CachedDS { + vk::DescriptorSet set{}; + std::vector buffers; // for invalidation + uint64_t lastUseTick = 0; + uint64_t useCount = 0; +}; + + +struct CachedBuf { + Buffer buf; + uint64_t lastUse=0; + uint64_t useCount=0; }; // Holds instance, physical device, logical device, queue and command pool. 
-class VulkanContext { -public: +struct VulkanContext { vk::Instance instance; vk::PhysicalDevice physicalDevice; + uint32_t queueFamilyIndex = UINT32_MAX; // compute + uint32_t graphicsFamilyIndex = UINT32_MAX; // present candidate vk::Device device; vk::Queue computeQueue; - uint32_t queueFamilyIndex; + vk::Queue presentQueue; vk::CommandPool commandPool; + vk::DescriptorPool descriptorPool; + std::unordered_map dsCache; + size_t dsCacheCapacity = 256; // must be ≤ pool maxSets + uint64_t dsUseTick = 1; + + std::unordered_map> bufferCache; // key: exact byte size + size_t bufferCacheCapacity = 128; // total entries across all sizes + uint64_t bufferUseTick = 1; + VulkanContext(); ~VulkanContext(); }; -// Creates a host‑visible storage buffer for read‑only or read‑write access. Buffer createBuffer(VulkanContext& ctx, size_t count, size_t dtypeSize, bool readOnly); - -// Releases the buffer and its memory. void destroyBuffer(VulkanContext& ctx, Buffer& buf); +void setBufferData(VulkanContext& ctx, Buffer& buf, const void* src, size_t bytes, size_t offset = 0); +void getBufferData(VulkanContext& ctx, const Buffer& buf, void* dst, size_t bytes, size_t offset = 0); -// Compiles a GLSL compute shader and builds a compute pipeline with descriptors. -ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, - const std::string& glsl_source, - const std::vector& readonlyBuffers, - const std::vector& readwriteBuffers); - +ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, const std::string& glsl, uint32_t roCount, uint32_t rwCount); ComputeProgram createComputeProgramFromSlang(VulkanContext& ctx, const std::string& moduleName, - const std::string& source, const std::string& entry, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers); - -// Destroys the compute program and associated resources. 
+ const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount); void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog); -// Dispatches a compute program with the given number of invocations. -void runProgram(VulkanContext& ctx, ComputeProgram& prog, uint32_t numInvocations); \ No newline at end of file +void runProgram(VulkanContext& ctx, const ComputeProgram& prog, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers, uint32_t n); \ No newline at end of file diff --git a/TensorFrost/include/Backend/Window.h b/TensorFrost/include/Backend/Window.h new file mode 100644 index 00000000..22d26a4e --- /dev/null +++ b/TensorFrost/include/Backend/Window.h @@ -0,0 +1,79 @@ +#pragma once +#include "Vulkan.h" +#include +#include +#include + +struct WindowContext { + GLFWwindow* wnd{}; + vk::Instance instance; + vk::PhysicalDevice phys; + vk::Device device; + uint32_t presentFam{}; + vk::Queue queue; + + vk::SurfaceKHR surface; + vk::SwapchainKHR swapchain; + std::vector images; + vk::Format format{}; + vk::Extent2D extent{}; + vk::CommandPool pool; + vk::CommandBuffer cmd; + vk::Semaphore semImage, semDone; + vk::Fence fence; + + WindowContext() = default; + WindowContext(const WindowContext&) = delete; + WindowContext& operator=(const WindowContext&) = delete; + + WindowContext(WindowContext&& o) noexcept { moveFrom(std::move(o)); } + WindowContext& operator=(WindowContext&& o) noexcept { + if (this != &o) { cleanup(); moveFrom(std::move(o)); } + return *this; + } + + ~WindowContext() { cleanup(); } + +private: + void moveFrom(WindowContext&& o) { + wnd=o.wnd; o.wnd=nullptr; + instance=o.instance; o.instance=nullptr; + phys=o.phys; o.phys=nullptr; + device=o.device; o.device=nullptr; + presentFam=o.presentFam; o.presentFam=0; + queue=o.queue; o.queue=nullptr; + surface=o.surface; o.surface=nullptr; + swapchain=o.swapchain; o.swapchain=nullptr; + images=std::move(o.images); + format=o.format; o.format=vk::Format{}; 
+ extent=o.extent; o.extent=vk::Extent2D{}; + pool=o.pool; o.pool=nullptr; + cmd=o.cmd; o.cmd=nullptr; + semImage=o.semImage; o.semImage=nullptr; + semDone=o.semDone; o.semDone=nullptr; + fence=o.fence; o.fence=nullptr; + } + + void cleanup() { + if (!wnd && !device) return; // already moved/clean + // don’t terminate GLFW here; only destroy this window + if (device) { + (void)device.waitIdle(); + if (fence) device.destroyFence(fence), fence=nullptr; + if (semDone) device.destroySemaphore(semDone), semDone=nullptr; + if (semImage) device.destroySemaphore(semImage), semImage=nullptr; + if (cmd) device.freeCommandBuffers(pool, cmd), cmd=nullptr; + if (pool) device.destroyCommandPool(pool), pool=nullptr; + if (swapchain) device.destroySwapchainKHR(swapchain), swapchain=nullptr; + } + if (surface) instance.destroySurfaceKHR(surface), surface=nullptr; + if (wnd) { glfwDestroyWindow(wnd); wnd=nullptr; } + // leave GLFW alive; app can call glfwTerminate() once at shutdown if it wants + device=nullptr; instance=nullptr; queue=nullptr; + } +}; + +WindowContext createWindow(VulkanContext& vctx, int width, int height, const char* title); +bool windowOpen(const WindowContext& ctx); +void drawBuffer(WindowContext& ctx, vk::Buffer src, uint32_t width, uint32_t height, vk::DeviceSize offset = 0); +void drawBuffer(WindowContext& ctx, const Buffer& b, uint32_t w, uint32_t h, size_t offset = 0); diff --git a/TensorFrost/include/TensorFrost.h b/TensorFrost/include/TensorFrost.h index c1165ff7..25e5c84b 100644 --- a/TensorFrost/include/TensorFrost.h +++ b/TensorFrost/include/TensorFrost.h @@ -8,4 +8,5 @@ #include "Compiler/Value.h" #include "Compiler/Printer.h" #include "Compiler/TFProgram.h" -#include "Backend/Vulkan.h" \ No newline at end of file +#include "Backend/Vulkan.h" +#include "Backend/Window.h" \ No newline at end of file diff --git a/TensorFrost/src/Backend/Vulkan.cpp b/TensorFrost/src/Backend/Vulkan.cpp index 78a1d46f..0ab38cbb 100644 --- 
a/TensorFrost/src/Backend/Vulkan.cpp +++ b/TensorFrost/src/Backend/Vulkan.cpp @@ -5,6 +5,318 @@ VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE #include #include +VulkanContext::VulkanContext() { + if (!glfwInit()) throw std::runtime_error("GLFW init failed"); + if (!glfwVulkanSupported()) + throw std::runtime_error("GLFW: Vulkan loader not found. Install a Vulkan-capable GPU driver or the Vulkan SDK."); + + VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr); + + uint32_t extCount = 0; + const char** extNames = glfwGetRequiredInstanceExtensions(&extCount); + if (!extNames || extCount == 0) throw std::runtime_error("GLFW: Vulkan not supported"); + + vk::ApplicationInfo appInfo("TensorFrost", 1, nullptr, 0, VK_API_VERSION_1_2); + vk::InstanceCreateInfo instCreate({}, &appInfo, 0, nullptr, extCount, extNames); + instance = vk::createInstance(instCreate); + VULKAN_HPP_DEFAULT_DISPATCHER.init(instance); + + // pick device + queue families (compute + graphics) + auto devices = instance.enumeratePhysicalDevices(); + if (devices.empty()) throw std::runtime_error("No physical devices"); + + for (auto& pd : devices) { + auto q = pd.getQueueFamilyProperties(); + int compute = -1, graphics = -1; + for (uint32_t i = 0; i < q.size(); ++i) { + auto f = q[i].queueFlags; + if (compute < 0 && (f & vk::QueueFlagBits::eCompute)) compute = int(i); + if (graphics < 0 && (f & vk::QueueFlagBits::eGraphics)) graphics = int(i); + } + if (compute >= 0 && graphics >= 0) { + physicalDevice = pd; + queueFamilyIndex = uint32_t(compute); + graphicsFamilyIndex = uint32_t(graphics); + break; + } + } + if (!physicalDevice) throw std::runtime_error("No suitable device"); + + // device with both queues + swapchain + float prio = 1.0f; + std::vector queues; + queues.emplace_back(vk::DeviceQueueCreateInfo({}, queueFamilyIndex, 1, &prio)); + if (graphicsFamilyIndex != queueFamilyIndex) + queues.emplace_back(vk::DeviceQueueCreateInfo({}, graphicsFamilyIndex, 1, &prio)); + + // enable 
VK_KHR_swapchain (required for window/swapchain) + const char* devExts[] = { VK_KHR_SWAPCHAIN_EXTENSION_NAME }; + vk::DeviceCreateInfo devCreate({}, (uint32_t)queues.size(), queues.data(), + 0, nullptr, 1, devExts); + device = physicalDevice.createDevice(devCreate); + VULKAN_HPP_DEFAULT_DISPATCHER.init(device); + + computeQueue = device.getQueue(queueFamilyIndex, 0); + presentQueue = device.getQueue(graphicsFamilyIndex, 0); + + vk::CommandPoolCreateInfo poolInfo({}, queueFamilyIndex); + commandPool = device.createCommandPool(poolInfo); + + vk::DescriptorPoolSize sz(vk::DescriptorType::eStorageBuffer, 1024); + vk::DescriptorPoolCreateInfo dp(vk::DescriptorPoolCreateFlagBits::eFreeDescriptorSet, 256, 1, &sz); + descriptorPool = device.createDescriptorPool(dp); + dsCacheCapacity = 256; +} + +static void evictSome(VulkanContext& ctx, size_t n) { + if (ctx.dsCache.empty() || n == 0) return; + std::vector::iterator> items; + items.reserve(ctx.dsCache.size()); + for (auto it = ctx.dsCache.begin(); it != ctx.dsCache.end(); ++it) items.push_back(it); + std::stable_sort(items.begin(), items.end(), + [](auto a, auto b){ return a->second.lastUseTick < b->second.lastUseTick; }); // LRU first + n = std::min(n, items.size()); + for (size_t i = 0; i < n; ++i) { + auto it = items[i]; + if (it->second.set) ctx.device.freeDescriptorSets(ctx.descriptorPool, 1, &it->second.set); + ctx.dsCache.erase(it); + } +} + +static void evictToCapacity(VulkanContext& ctx) { + if (ctx.dsCache.size() > ctx.dsCacheCapacity) + evictSome(ctx, ctx.dsCache.size() - ctx.dsCacheCapacity); +} + +static void invalidateDescriptorCacheForBuffer(VulkanContext& ctx, VkBuffer buf) { + std::vector dead; + for (auto it = ctx.dsCache.begin(); it != ctx.dsCache.end(); ++it) + if (std::find(it->second.buffers.begin(), it->second.buffers.end(), buf) != it->second.buffers.end()) + dead.push_back(it); + for (auto it : dead) { + if (it->second.set) ctx.device.freeDescriptorSets(ctx.descriptorPool, 1, &it->second.set); + 
ctx.dsCache.erase(it); + } +} + +static void invalidateDescriptorCacheForLayout(VulkanContext& ctx, vk::DescriptorSetLayout layout) { + std::vector dead; + for (auto it = ctx.dsCache.begin(); it != ctx.dsCache.end(); ++it) + if (it->first.layout == layout) dead.push_back(it); + for (auto it : dead) { + if (it->second.set) ctx.device.freeDescriptorSets(ctx.descriptorPool, 1, &it->second.set); + ctx.dsCache.erase(it); + } +} + +// optional knobs +void setDescriptorCacheCapacity(VulkanContext& ctx, size_t cap) { + ctx.dsCacheCapacity = cap; + evictToCapacity(ctx); +} +void clearDescriptorCache(VulkanContext& ctx) { + evictSome(ctx, ctx.dsCache.size()); +} + +// --- cached descriptor set retrieval +static vk::DescriptorSet getOrCreateSet(VulkanContext& ctx, const ComputeProgram& prog, + const std::vector& ro, const std::vector& rw) { + if (ro.size()!=prog.numRO || rw.size()!=prog.numRW) throw std::runtime_error("buffer count != program layout"); + + DSKey key; key.layout = prog.descriptorLayout; + key.bufs.reserve(ro.size()+rw.size()); + key.sizes.reserve(ro.size()+rw.size()); + for (auto* b : ro) { key.bufs.push_back(b->buffer); key.sizes.push_back(b->size); } + for (auto* b : rw) { key.bufs.push_back(b->buffer); key.sizes.push_back(b->size); } + + if (auto it = ctx.dsCache.find(key); it != ctx.dsCache.end()) { + it->second.lastUseTick = ++ctx.dsUseTick; + it->second.useCount++; + return it->second.set; + } + + evictToCapacity(ctx); + + vk::DescriptorSet set{}; + for (int attempt = 0; attempt < 3; ++attempt) { + try { + vk::DescriptorSetAllocateInfo ai(ctx.descriptorPool, 1, &prog.descriptorLayout); + set = ctx.device.allocateDescriptorSets(ai)[0]; + break; + } catch (const vk::SystemError& e) { + auto r = static_cast(e.code().value()); + if (r == vk::Result::eErrorOutOfPoolMemory || r == vk::Result::eErrorFragmentedPool) { + evictSome(ctx, std::max(1, ctx.dsCache.size()/4)); + } else { + throw; + } + } + } + if (!set) throw std::runtime_error("Descriptor set 
allocation failed"); + + std::vector infos; infos.reserve(key.bufs.size()); + for (auto* b : ro) infos.emplace_back(b->buffer, 0, b->size); + for (auto* b : rw) infos.emplace_back(b->buffer, 0, b->size); + + std::vector writes; writes.reserve(infos.size()); + for (uint32_t i=0;i nodes; nodes.reserve(bufferCacheCount(ctx)); + + for (auto& kv : ctx.bufferCache) + for (auto& e : kv.second) + nodes.push_back(Node{ kv.first, e.buf.buffer, e.lastUse }); + + std::stable_sort(nodes.begin(), nodes.end(), + [](const Node& a, const Node& b){ return a.last < b.last; }); + + if (n > nodes.size()) n = nodes.size(); + + for (size_t i = 0; i < n; ++i) { + auto itMap = ctx.bufferCache.find(nodes[i].key); + if (itMap == ctx.bufferCache.end()) continue; + + auto& vec = itMap->second; + + auto it = std::find_if(vec.begin(), vec.end(), + [&](const CachedBuf& e){ + return e.buf.buffer == nodes[i].buf; + }); + if (it == vec.end()) continue; // already evicted by an earlier iteration + + Buffer b = it->buf; + vec.erase(it); + if (vec.empty()) ctx.bufferCache.erase(itMap); + + if (b.buffer) ctx.device.destroyBuffer(b.buffer); + if (b.memory) ctx.device.freeMemory(b.memory); + } +} + +static void evictBuffersToCapacity(VulkanContext& ctx) { + size_t cnt = bufferCacheCount(ctx); + if (cnt > ctx.bufferCacheCapacity) evictBuffers(ctx, cnt - ctx.bufferCacheCapacity); +} + +void clearBufferCache(VulkanContext& ctx) { evictBuffers(ctx, bufferCacheCount(ctx)); } + +void setBufferCacheCapacity(VulkanContext& ctx, size_t cap) { ctx.bufferCacheCapacity = cap; evictBuffersToCapacity(ctx); } + +static bool takeBufferFromCache(VulkanContext& ctx, size_t bytes, Buffer& out) { + auto it = ctx.bufferCache.find(bytes); + if (it == ctx.bufferCache.end() || it->second.empty()) return false; + auto e = std::move(it->second.back()); + it->second.pop_back(); + if (it->second.empty()) ctx.bufferCache.erase(it); + out = e.buf; // handles copied; cache keeps no reference + return true; +} + +// create a storage 
buffer +Buffer createBuffer(VulkanContext& ctx, size_t count, size_t dtypeSize, bool readOnly) { + Buffer buf{}; + buf.size = count * dtypeSize; + + if (takeBufferFromCache(ctx, buf.size, buf)) return buf; + + vk::BufferCreateInfo bci({}, buf.size, + vk::BufferUsageFlagBits::eStorageBuffer | + vk::BufferUsageFlagBits::eTransferSrc | + vk::BufferUsageFlagBits::eTransferDst); + buf.buffer = ctx.device.createBuffer(bci); + auto memReq = ctx.device.getBufferMemoryRequirements(buf.buffer); + + auto memProps = ctx.physicalDevice.getMemoryProperties(); + uint32_t memTypeIndex = 0; + for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { + if ((memReq.memoryTypeBits & (1u< buf.size) throw std::out_of_range("write out of range"); + auto atom = ctx.physicalDevice.getProperties().limits.nonCoherentAtomSize; + vk::DeviceSize mapOff = offset - (offset % atom); + vk::DeviceSize mapEnd = ((offset + bytes + atom - 1) / atom) * atom; + vk::DeviceSize mapSz = mapEnd - mapOff; + + void* p = ctx.device.mapMemory(buf.memory, mapOff, mapSz); + std::memcpy(static_cast(p) + (offset - mapOff), src, bytes); + vk::MappedMemoryRange rng(buf.memory, mapOff, mapSz); + ctx.device.flushMappedMemoryRanges(rng); // needed if memory not coherent + ctx.device.unmapMemory(buf.memory); +} + +void getBufferData(VulkanContext& ctx, const Buffer& buf, void* dst, size_t bytes, size_t offset) { + if (offset + bytes > buf.size) throw std::out_of_range("read out of range"); + auto atom = ctx.physicalDevice.getProperties().limits.nonCoherentAtomSize; + vk::DeviceSize mapOff = offset - (offset % atom); + vk::DeviceSize mapEnd = ((offset + bytes + atom - 1) / atom) * atom; + vk::DeviceSize mapSz = mapEnd - mapOff; + + void* p = ctx.device.mapMemory(buf.memory, mapOff, mapSz); + vk::MappedMemoryRange rng(buf.memory, mapOff, mapSz); + ctx.device.invalidateMappedMemoryRanges(rng); // needed if memory not coherent + std::memcpy(dst, static_cast(p) + (offset - mapOff), bytes); + ctx.device.unmapMemory(buf.memory); 
+} + +VulkanContext::~VulkanContext() { + // free cached sets first + for (auto it = dsCache.begin(); it != dsCache.end(); ++it) { + if (it->second.set) device.freeDescriptorSets(descriptorPool, 1, &it->second.set); + } + dsCache.clear(); + clearBufferCache(*this); + device.destroyDescriptorPool(descriptorPool); + device.destroyCommandPool(commandPool); + device.destroy(); + instance.destroy(); +} + // compile GLSL to SPIR-V at runtime static std::vector compileGLSLToSpirv(const std::string& source) { shaderc::Compiler compiler; @@ -77,196 +389,96 @@ std::vector compileSlangToSpirv(const char* moduleName, return out; } +ComputeBindings createBindings(VulkanContext& ctx, const ComputeProgram& prog, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers) { + if (readonlyBuffers.size() != prog.numRO || readwriteBuffers.size() != prog.numRW) + throw std::runtime_error("buffer count != program layout"); -// VulkanContext constructor sets up instance, selects a compute device and queue, and creates a command pool. 
-VulkanContext::VulkanContext() { - VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr); // required before vk::createInstance - - // 1) instance - vk::ApplicationInfo appInfo("ComputeFramework", 1, nullptr, 0, VK_API_VERSION_1_1); - vk::InstanceCreateInfo instCreate({}, &appInfo); - instance = vk::createInstance(instCreate); - VULKAN_HPP_DEFAULT_DISPATCHER.init(instance); // load instance-level funcs - - // 2) pick physical device + compute queue family - auto devices = instance.enumeratePhysicalDevices(); - if (devices.empty()) throw std::runtime_error("No physical devices"); - for (auto& pd : devices) { - auto q = pd.getQueueFamilyProperties(); - for (uint32_t i = 0; i < q.size(); ++i) { - if ( (q[i].queueFlags & vk::QueueFlagBits::eCompute) != vk::QueueFlags{} ) { - physicalDevice = pd; - queueFamilyIndex = i; - break; - } - } - if (physicalDevice) break; - } - if (!physicalDevice) throw std::runtime_error("No compute queue"); - - // 3) device + queue - float prio = 1.0f; - vk::DeviceQueueCreateInfo qci({}, queueFamilyIndex, 1, &prio); - vk::DeviceCreateInfo devCreate({}, qci); - device = physicalDevice.createDevice(devCreate); - VULKAN_HPP_DEFAULT_DISPATCHER.init(device); // load device-level funcs - - computeQueue = device.getQueue(queueFamilyIndex, 0); - - // 4) command pool - vk::CommandPoolCreateInfo poolInfo({}, queueFamilyIndex); - commandPool = device.createCommandPool(poolInfo); -} - -// VulkanContext destructor cleans up the command pool, device and instance. 
-VulkanContext::~VulkanContext() { - device.destroyCommandPool(commandPool); - device.destroy(); - instance.destroy(); -} + vk::DescriptorSetAllocateInfo ai(ctx.descriptorPool, 1, &prog.descriptorLayout); + ComputeBindings b{}; + b.set = ctx.device.allocateDescriptorSets(ai)[0]; -// create a storage buffer -Buffer createBuffer(VulkanContext& ctx, size_t count, size_t dtypeSize, bool readOnly) { - Buffer buf; - buf.size = count * dtypeSize; - vk::BufferCreateInfo bci({}, buf.size, - vk::BufferUsageFlagBits::eStorageBuffer | - vk::BufferUsageFlagBits::eTransferSrc | - vk::BufferUsageFlagBits::eTransferDst); - buf.buffer = ctx.device.createBuffer(bci); - auto memReq = ctx.device.getBufferMemoryRequirements(buf.buffer); + std::vector infos; + infos.reserve(prog.numRO + prog.numRW); + for (auto* x : readonlyBuffers) infos.emplace_back(x->buffer, 0, x->size); + for (auto* x : readwriteBuffers) infos.emplace_back(x->buffer, 0, x->size); - auto memProps = ctx.physicalDevice.getMemoryProperties(); - uint32_t memTypeIndex = 0; - for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { - bool allowed = memReq.memoryTypeBits & (1u << i); - auto typeBits = memReq.memoryTypeBits; - auto flags = memProps.memoryTypes[i].propertyFlags; - - bool ok = (typeBits & (1u << i)) != 0; - bool hostVis = (flags & vk::MemoryPropertyFlagBits::eHostVisible) != vk::MemoryPropertyFlags{}; - if (allowed && hostVis) { - memTypeIndex = i; - break; - } - } - vk::MemoryAllocateInfo allocInfo(memReq.size, memTypeIndex); - buf.memory = ctx.device.allocateMemory(allocInfo); - ctx.device.bindBufferMemory(buf.buffer, buf.memory, 0); - return buf; -} + std::vector writes; + writes.reserve(infos.size()); + for (uint32_t i = 0; i < infos.size(); ++i) + writes.emplace_back(b.set, i, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &infos[i]); -void destroyBuffer(VulkanContext& ctx, Buffer& buf) { - ctx.device.destroyBuffer(buf.buffer); - ctx.device.freeMemory(buf.memory); - buf.buffer = nullptr; - 
buf.memory = nullptr; + ctx.device.updateDescriptorSets(writes, {}); + return b; } -// internal helper to build a compute program from SPIR-V static ComputeProgram createComputeProgram(VulkanContext& ctx, const std::vector& spirv, - const std::vector& readonlyBuffers, - const std::vector& readwriteBuffers) { + uint32_t roCount, uint32_t rwCount) { ComputeProgram prog; + prog.numRO = roCount; prog.numRW = rwCount; + vk::ShaderModuleCreateInfo smci({}, spirv.size() * sizeof(uint32_t), spirv.data()); prog.shaderModule = ctx.device.createShaderModule(smci); std::vector bindings; - uint32_t binding = 0; - for (size_t i = 0; i < readonlyBuffers.size(); i++) { - bindings.emplace_back(binding++, vk::DescriptorType::eStorageBuffer, 1, - vk::ShaderStageFlagBits::eCompute); - } - for (size_t i = 0; i < readwriteBuffers.size(); i++) { - bindings.emplace_back(binding++, vk::DescriptorType::eStorageBuffer, 1, - vk::ShaderStageFlagBits::eCompute); - } + bindings.reserve(roCount + rwCount); + for (uint32_t b = 0; b < roCount + rwCount; ++b) + bindings.emplace_back(b, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute); + vk::DescriptorSetLayoutCreateInfo dsInfo({}, bindings.size(), bindings.data()); prog.descriptorLayout = ctx.device.createDescriptorSetLayout(dsInfo); vk::PipelineLayoutCreateInfo plInfo({}, 1, &prog.descriptorLayout); prog.pipelineLayout = ctx.device.createPipelineLayout(plInfo); - vk::PipelineShaderStageCreateInfo stageInfo({}, vk::ShaderStageFlagBits::eCompute, - prog.shaderModule, "main"); + vk::PipelineShaderStageCreateInfo stageInfo({}, vk::ShaderStageFlagBits::eCompute, prog.shaderModule, "main"); vk::ComputePipelineCreateInfo cpInfo({}, stageInfo, prog.pipelineLayout); prog.pipeline = ctx.device.createComputePipeline({}, cpInfo).value; - vk::DescriptorPoolSize poolSize(vk::DescriptorType::eStorageBuffer, - readonlyBuffers.size() + readwriteBuffers.size()); - vk::DescriptorPoolCreateInfo poolInfo({}, 1, 1, &poolSize); - 
prog.descriptorPool = ctx.device.createDescriptorPool(poolInfo); - vk::DescriptorSetAllocateInfo allocInfo(prog.descriptorPool, 1, &prog.descriptorLayout); - prog.descriptorSet = ctx.device.allocateDescriptorSets(allocInfo)[0]; - - std::vector bufferInfos; - bufferInfos.reserve(readonlyBuffers.size() + readwriteBuffers.size()); - for (auto b : readonlyBuffers) { - bufferInfos.push_back(vk::DescriptorBufferInfo(b->buffer, 0, b->size)); - } - for (auto b : readwriteBuffers) { - bufferInfos.push_back(vk::DescriptorBufferInfo(b->buffer, 0, b->size)); - } - std::vector writes; - for (uint32_t i = 0; i < bufferInfos.size(); i++) { - vk::WriteDescriptorSet w(prog.descriptorSet, i, 0, 1, - vk::DescriptorType::eStorageBuffer, nullptr, - &bufferInfos[i]); - writes.push_back(w); - } - ctx.device.updateDescriptorSets(writes, {}); return prog; } -// public wrapper that compiles GLSL and builds the program -ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, - const std::string& glsl_source, - const std::vector& readonlyBuffers, - const std::vector& readwriteBuffers) { - - auto spirv = compileGLSLToSpirv(glsl_source); - return createComputeProgram(ctx, spirv, readonlyBuffers, readwriteBuffers); +ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, const std::string& glsl, uint32_t roCount, uint32_t rwCount) { + auto spirv = compileGLSLToSpirv(glsl); + return createComputeProgram(ctx, spirv, roCount, rwCount); } - ComputeProgram createComputeProgramFromSlang(VulkanContext& ctx, const std::string& moduleName, - const std::string& source, const std::string& entry, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers) { + const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount) { auto spirv = compileSlangToSpirv(moduleName.c_str(), source.c_str(), entry.c_str(), "spirv_1_5"); - return createComputeProgram(ctx, spirv, readonlyBuffers, readwriteBuffers); + return createComputeProgram(ctx, spirv, roCount, 
rwCount); } void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog) { - ctx.device.destroyDescriptorPool(prog.descriptorPool); + invalidateDescriptorCacheForLayout(ctx, prog.descriptorLayout); ctx.device.destroyPipeline(prog.pipeline); ctx.device.destroyPipelineLayout(prog.pipelineLayout); ctx.device.destroyDescriptorSetLayout(prog.descriptorLayout); ctx.device.destroyShaderModule(prog.shaderModule); - prog.pipeline = nullptr; - prog.pipelineLayout = nullptr; - prog.descriptorLayout = nullptr; - prog.descriptorPool = nullptr; - prog.shaderModule = nullptr; + prog = {}; } -// dispatch compute commands -void runProgram(VulkanContext& ctx, ComputeProgram& prog, uint32_t n) { +void runProgram(VulkanContext& ctx, const ComputeProgram& prog, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers, + uint32_t n) { + auto set = getOrCreateSet(ctx, prog, readonlyBuffers, readwriteBuffers); + vk::CommandBufferAllocateInfo ai(ctx.commandPool, vk::CommandBufferLevel::ePrimary, 1); auto cmd = ctx.device.allocateCommandBuffers(ai)[0]; cmd.begin(vk::CommandBufferBeginInfo{}); cmd.bindPipeline(vk::PipelineBindPoint::eCompute, prog.pipeline); - cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, prog.pipelineLayout, 0, prog.descriptorSet, {}); + cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, prog.pipelineLayout, 0, set, {}); uint32_t gs = 64, groups = (n + gs - 1) / gs; cmd.dispatch(groups, 1, 1); cmd.end(); vk::Fence fence = ctx.device.createFence({}); ctx.computeQueue.submit(vk::SubmitInfo(0, nullptr, 0, 1, &cmd), fence); - [[maybe_unused]] auto rWait = ctx.device.waitForFences(fence, VK_TRUE, UINT64_MAX); - // if you compile with VULKAN_HPP_NO_EXCEPTIONS=1, you can check: - // if (rWait != vk::Result::eSuccess) throw std::runtime_error("waitForFences failed"); - ctx.device.destroyFence(fence); ctx.device.freeCommandBuffers(ctx.commandPool, cmd); -} \ No newline at end of file +} + diff --git a/TensorFrost/src/Backend/Window.cpp 
b/TensorFrost/src/Backend/Window.cpp new file mode 100644 index 00000000..2fc4ab2e --- /dev/null +++ b/TensorFrost/src/Backend/Window.cpp @@ -0,0 +1,122 @@ +#include "Backend/Vulkan.h" +#include "Backend/Window.h" + +WindowContext createWindow(VulkanContext& vctx, int width, int height, const char* title) { + if (!glfwInit()) throw std::runtime_error("glfwInit"); + glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API); + + WindowContext ctx{}; + ctx.wnd = glfwCreateWindow(width, height, title, nullptr, nullptr); + if (!ctx.wnd) throw std::runtime_error("glfwCreateWindow"); + + ctx.instance = vctx.instance; + ctx.phys = vctx.physicalDevice; + ctx.device = vctx.device; + + // create surface on the shared instance + VkSurfaceKHR raw{}; + if (glfwCreateWindowSurface(ctx.instance, ctx.wnd, nullptr, &raw) != VK_SUCCESS) + throw std::runtime_error("glfw surface"); + ctx.surface = raw; + + // pick a present-capable family; prefer the graphics family we provisioned + if (ctx.phys.getSurfaceSupportKHR(vctx.graphicsFamilyIndex, ctx.surface)) { + ctx.presentFam = vctx.graphicsFamilyIndex; + ctx.queue = vctx.presentQueue; + } else if (ctx.phys.getSurfaceSupportKHR(vctx.queueFamilyIndex, ctx.surface)) { + ctx.presentFam = vctx.queueFamilyIndex; + ctx.queue = vctx.computeQueue; // uncommon but legal if it supports present + } else { + throw std::runtime_error("device queues do not support presenting to this surface"); + } + + auto caps = ctx.phys.getSurfaceCapabilitiesKHR(ctx.surface); + auto fmts = ctx.phys.getSurfaceFormatsKHR(ctx.surface); + vk::SurfaceFormatKHR sf = fmts.empty() + ? 
vk::SurfaceFormatKHR{vk::Format::eB8G8R8A8Unorm, vk::ColorSpaceKHR::eSrgbNonlinear} + : fmts[0]; + if (!(caps.supportedUsageFlags & vk::ImageUsageFlagBits::eTransferDst)) + throw std::runtime_error("swapchain missing TRANSFER_DST"); + + int fbw, fbh; glfwGetFramebufferSize(ctx.wnd, &fbw, &fbh); + ctx.extent = vk::Extent2D{ + uint32_t(std::clamp(fbw, int(caps.minImageExtent.width), int(caps.maxImageExtent.width))), + uint32_t(std::clamp(fbh, int(caps.minImageExtent.height), int(caps.maxImageExtent.height))) + }; + ctx.format = sf.format; + + uint32_t imageCount = std::max(caps.minImageCount, 2u); + if (caps.maxImageCount) imageCount = std::min(imageCount, caps.maxImageCount); + + ctx.swapchain = ctx.device.createSwapchainKHR({ + {}, ctx.surface, imageCount, sf.format, sf.colorSpace, ctx.extent, 1, + vk::ImageUsageFlagBits::eTransferDst, + vk::SharingMode::eExclusive, 0, nullptr, + caps.currentTransform, vk::CompositeAlphaFlagBitsKHR::eOpaque, + vk::PresentModeKHR::eFifo, VK_TRUE, {} + }); + ctx.images = ctx.device.getSwapchainImagesKHR(ctx.swapchain); + + ctx.pool = ctx.device.createCommandPool({vk::CommandPoolCreateFlagBits::eResetCommandBuffer, ctx.presentFam}); + ctx.cmd = ctx.device.allocateCommandBuffers({ctx.pool, vk::CommandBufferLevel::ePrimary, 1})[0]; + ctx.semImage = ctx.device.createSemaphore({}); + ctx.semDone = ctx.device.createSemaphore({}); + ctx.fence = ctx.device.createFence({}); + return ctx; +} + +bool windowOpen(const WindowContext &ctx) { + return ctx.wnd && !glfwWindowShouldClose(ctx.wnd); +} + +void drawBuffer(WindowContext &ctx, vk::Buffer src, uint32_t width, uint32_t height, vk::DeviceSize offset) { + glfwPollEvents(); + + auto acq = ctx.device.acquireNextImageKHR(ctx.swapchain, UINT64_MAX, ctx.semImage, {}); + if (acq.result == vk::Result::eSuboptimalKHR) {} // continue + if (acq.result == vk::Result::eErrorOutOfDateKHR) return; // ignore; recreate swapchain if you want + uint32_t idx = acq.value; + + ctx.cmd.reset({}); + 
ctx.cmd.begin({vk::CommandBufferUsageFlagBits::eOneTimeSubmit}); + + vk::ImageMemoryBarrier toDst({}, vk::AccessFlagBits::eTransferWrite, + vk::ImageLayout::eUndefined, vk::ImageLayout::eTransferDstOptimal, + VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, + ctx.images[idx], {vk::ImageAspectFlagBits::eColor, 0,1, 0,1}); + ctx.cmd.pipelineBarrier(vk::PipelineStageFlagBits::eTopOfPipe, + vk::PipelineStageFlagBits::eTransfer, {}, {}, {}, toDst); + + vk::BufferImageCopy copy{}; + copy.bufferOffset = offset; + copy.imageSubresource = {vk::ImageAspectFlagBits::eColor, 0, 0, 1}; + copy.imageExtent = vk::Extent3D{ width, height, 1 }; + ctx.cmd.copyBufferToImage(src, ctx.images[idx], vk::ImageLayout::eTransferDstOptimal, 1, &copy); + + vk::ImageMemoryBarrier toPresent(vk::AccessFlagBits::eTransferWrite, {}, + vk::ImageLayout::eTransferDstOptimal, vk::ImageLayout::ePresentSrcKHR, + VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, + ctx.images[idx], {vk::ImageAspectFlagBits::eColor, 0,1, 0,1}); + ctx.cmd.pipelineBarrier(vk::PipelineStageFlagBits::eTransfer, + vk::PipelineStageFlagBits::eBottomOfPipe, {}, {}, {}, toPresent); + + ctx.cmd.end(); + + vk::PipelineStageFlags waitStage = vk::PipelineStageFlagBits::eTransfer; + + (void)ctx.device.resetFences(ctx.fence); + ctx.queue.submit({vk::SubmitInfo(1, &ctx.semImage, &waitStage, 1, &ctx.cmd, 1, &ctx.semDone)}, ctx.fence); + (void)ctx.device.waitForFences(ctx.fence, VK_TRUE, UINT64_MAX); + + try { + (void)ctx.queue.presentKHR({1, &ctx.semDone, 1, &ctx.swapchain, &idx}); + } catch (const vk::OutOfDateKHRError&) { + // ignore + } +} + +void drawBuffer(WindowContext &ctx, const Buffer &b, uint32_t w, uint32_t h, size_t offset) { + // optional sanity check: + if (offset + size_t(w)*size_t(h)*4 > b.size) throw std::out_of_range("buffer too small"); + drawBuffer(ctx, b.buffer, w, h, offset); +} diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py new file mode 100644 index 00000000..66d236c0 --- /dev/null +++ 
b/examples/Slang/mandelbrot.py @@ -0,0 +1,37 @@ +import numpy as np +import TensorFrost as tf +from pathlib import Path + +ctx = tf.VulkanContext() + +W, H = 1024, 768 +win = tf.createWindow(ctx, W, H, "Mandelbrot") +fmt = int(win.format) +is_bgra = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB + +pix = tf.createBuffer(ctx, W*H, 4, False) # uint32 pixels +params = tf.createBuffer(ctx, 8, 4, True) # 8 float32 params + +with open(Path(__file__).with_name('mandelbrot.slang'), 'r') as f: + hlsl = f.read() + +prog = tf.createComputeProgramFromSlang(ctx, "mandelbrot", hlsl, "csMain", roCount=1, rwCount=1) + +# view rectangle with aspect correction +xspan = 3.0 +yspan = xspan * (H / float(W)) +xmin, ymin = -2.0, -yspan * 0.5 +dx, dy = xspan / W, yspan / H +max_iter = 500.0 + +p = np.array([float(W), float(H), xmin, ymin, dx, dy, max_iter, 1.0 if is_bgra else 0.0], dtype=np.float32) +tf.setBufferData(ctx, params, p) + +try: + while tf.windowOpen(win): + tf.runProgram(ctx, prog, [params], [pix], W*H) + tf.drawBuffer(win, pix, W, H) +finally: + tf.destroyComputeProgram(ctx, prog) + tf.destroyBuffer(ctx, pix) + tf.destroyBuffer(ctx, params) diff --git a/examples/Slang/mandelbrot.slang b/examples/Slang/mandelbrot.slang new file mode 100644 index 00000000..4470166f --- /dev/null +++ b/examples/Slang/mandelbrot.slang @@ -0,0 +1,50 @@ +[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Pixels : register(u1, space0); + +static float3 palette(float t) { + return sin(6.3 * float3(0.0, 0.33, 0.67) * t); +} + +[numthreads(64,1,1)] +void csMain(uint3 tid : SV_DispatchThreadID) +{ + uint idx = tid.x; + int W = (int)(Params[0] + 0.5); + int H = (int)(Params[1] + 0.5); + uint N = (uint)(W * H); + if (idx >= N) return; + + int x = int(idx % (uint)W); + int y = int(idx / (uint)W); + + float xmin = Params[2], ymin = Params[3], dx = Params[4], dy = Params[5]; + int maxIter = (int)(Params[6] + 0.5); + bool isBGRA = (Params[7] > 
0.5); + + float cx = xmin + float(x) * dx; + float cy = ymin + float(y) * dy; + + float zx = 0.0, zy = 0.0; + int i = 0; + [loop] + for (; i < maxIter; ++i) { + float zx2 = zx*zx - zy*zy + cx; + zy = 2.0*zx*zy + cy; + zx = zx2; + if (zx*zx + zy*zy > 4.0) break; + } + + float t = (float(i) - log2(log(sqrt(zx*zx + zy*zy))) + 4.0); + t = clamp(t / float(maxIter), 0.0, 1.0); + + float3 rgb = palette(t); + float4 c = float4(rgb, 1.0); + + uint r = (uint)round(saturate(c.r) * 255.0); + uint g = (uint)round(saturate(c.g) * 255.0); + uint b = (uint)round(saturate(c.b) * 255.0); + uint a = (uint)round(saturate(c.a) * 255.0); + uint packed = isBGRA ? (b | (g<<8) | (r<<16) | (a<<24)) + : (r | (g<<8) | (b<<16) | (a<<24)); + Pixels[idx] = packed; +} \ No newline at end of file diff --git a/examples/debug.py b/examples/debug.py index de3cdd48..fac28890 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -1,48 +1,86 @@ import numpy as np import TensorFrost as tf -ctx = tf.VulkanContext() - -N = 1024 -aBuf = tf.createBuffer(ctx, N, 4, True) -bBuf = tf.createBuffer(ctx, N, 4, True) -outBuf = tf.createBuffer(ctx, N, 4, False) - -# write inputs via mapped views -a_map = ctx.device.mapMemory(aBuf.memory, 0, aBuf.size) -b_map = ctx.device.mapMemory(bBuf.memory, 0, bBuf.size) -np.frombuffer(a_map, dtype=np.float32)[:] = np.arange(N, dtype=np.float32) -np.frombuffer(b_map, dtype=np.float32)[:] = 2 * np.arange(N, dtype=np.float32) -ctx.device.unmapMemory(aBuf.memory) -ctx.device.unmapMemory(bBuf.memory) - -code = r''' -[[vk::binding(2,0)]] RWStructuredBuffer C; -[[vk::binding(0,0)]] StructuredBuffer A; -[[vk::binding(1,0)]] StructuredBuffer B; - -[shader("compute")] [numthreads(64,1,1)] -void computeMain(uint3 tid: SV_DispatchThreadID) { - C[tid.x] = 2.0f*A[tid.x] + B[tid.x]; +# GLSL: 1D dispatch (local_size_x=64). Pixels are packed with packUnorm4x8. 
+glsl = r""" +#version 450 +layout(local_size_x = 64) in; + +layout(std430, binding = 0) readonly buffer Params { float p[]; }; // [w,h,xmin,ymin,dx,dy,maxIter,isBGRA] +layout(std430, binding = 1) writeonly buffer Pixels { uint out_u32[]; }; + +vec3 palette(float t) { + // simple smooth palette + return vec3(0.5 + 0.5*cos(6.28318*(vec3(0.0,0.33,0.67)+t))); } -''' -prog = tf.createComputeProgramFromSlang( - ctx, "vecadd", code, "computeMain", - [aBuf, bBuf], [outBuf] -) +void main() { + uint idx1D = gl_GlobalInvocationID.x; + int W = int(p[0] + 0.5), H = int(p[1] + 0.5); + uint N = uint(W*H); + if (idx1D >= N) return; + + int x = int(idx1D % uint(W)); + int y = int(idx1D / uint(W)); + + float xmin = p[2], ymin = p[3], dx = p[4], dy = p[5]; + int maxIter = int(p[6] + 0.5); + bool isBGRA = (p[7] > 0.5); + + float cx = xmin + float(x) * dx; + float cy = ymin + float(y) * dy; + + float zx = 0.0, zy = 0.0; + int i = 0; + for (; i < maxIter; ++i) { + float zx2 = zx*zx - zy*zy + cx; + float zy2 = 2.0*zx*zy + cy; + zx = zx2; zy = zy2; + if (zx*zx + zy*zy > 4.0) break; + } + + float t = (i == maxIter) ? 0.0 : + float(i) - log2(log(length(vec2(zx,zy)))) + 4.0; + t = clamp(t / float(maxIter), 0.0, 1.0); + + vec3 rgb = palette(t); + vec4 c = vec4(rgb, 1.0); + uint packed = isBGRA ? 
packUnorm4x8(c.bgra) : packUnorm4x8(c); + out_u32[idx1D] = packed; +} +""" + +def main(): + ctx = tf.VulkanContext() + + W, H = 1024, 768 + win = tf.createWindow(ctx, W, H, "Mandelbrot (compute → buffer → swapchain)") + fmt = int(win.format) + is_bgra = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB + + pix = tf.createBuffer(ctx, W*H, 4, False) # uint32 pixels + params = tf.createBuffer(ctx, 8, 4, True) # 8 float32 params + + prog = tf.createComputeProgramFromGLSL(ctx, glsl, roCount=1, rwCount=1) -tf.runProgram(ctx, prog, N) + # view rectangle with aspect correction + xspan = 3.0 + yspan = xspan * (H / float(W)) + xmin, ymin = -2.0, -yspan * 0.5 + dx, dy = xspan / W, yspan / H + max_iter = 500.0 -# read back -out_map = ctx.device.mapMemory(outBuf.memory, 0, outBuf.size) -out = np.frombuffer(out_map, dtype=np.float32).copy() -ctx.device.unmapMemory(outBuf.memory) + p = np.array([float(W), float(H), xmin, ymin, dx, dy, max_iter, 1.0 if is_bgra else 0.0], dtype=np.float32) + tf.setBufferData(ctx, params, p) -ok = np.allclose(out, 4 * np.arange(N, dtype=np.float32)) -print("Compute result is", "correct" if ok else "incorrect") + try: + while tf.windowOpen(win): + tf.runProgram(ctx, prog, [params], [pix], W*H) + tf.drawBuffer(win, pix, W, H) + finally: + tf.destroyComputeProgram(ctx, prog) + tf.destroyBuffer(ctx, pix) + tf.destroyBuffer(ctx, params) -tf.destroyComputeProgram(ctx, prog) -tf.destroyBuffer(ctx, aBuf) -tf.destroyBuffer(ctx, bBuf) -tf.destroyBuffer(ctx, outBuf) \ No newline at end of file +if __name__ == "__main__": + main() \ No newline at end of file From 6b848f5c9b6c9a39c6fb35874dfab109cacc6d0d Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Wed, 5 Nov 2025 23:26:20 +0100 Subject: [PATCH 21/44] Refactor vulkan interface --- TensorFrost/PybindModule.cpp | 99 +----- .../include/Definitions/VulkanBindings.h | 9 + .../src/Definitions/VulkanBindings.cpp | 94 ++++++ .../src/Definitions/VulkanInterface.cpp 
| 301 ++++++++++++++++++ TensorFrost/src/Definitions/VulkanInterface.h | 123 +++++++ examples/Slang/mandelbrot.py | 22 +- examples/debug.py | 22 +- setup_python_env.cmd | 65 ++++ tests/vulkan_window_test.py | 54 ++++ 9 files changed, 666 insertions(+), 123 deletions(-) create mode 100644 TensorFrost/include/Definitions/VulkanBindings.h create mode 100644 TensorFrost/src/Definitions/VulkanBindings.cpp create mode 100644 TensorFrost/src/Definitions/VulkanInterface.cpp create mode 100644 TensorFrost/src/Definitions/VulkanInterface.h create mode 100644 setup_python_env.cmd create mode 100644 tests/vulkan_window_test.py diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 35f8468c..dc4331ea 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -10,6 +10,7 @@ #include #include "TensorFrost.h" +#include "Definitions/VulkanBindings.h" namespace py = pybind11; @@ -24,15 +25,6 @@ namespace TensorFrost { // void ScopeDefinitions(py::module& m, py::class_& py_tensor); // void ModuleDefinitions(py::module& m); -static bool is_c_contig(const py::buffer_info& i) { - py::ssize_t stride = i.itemsize; - for (py::ssize_t d = i.ndim - 1; d >= 0; --d) { - if (i.strides[d] != stride) return false; - stride *= i.shape[d]; - } - return true; -} - PYBIND11_MODULE(TensorFrost, m) { m.doc() = "TensorFrost library"; // auto data_type = py::enum_(m, "TFType"); @@ -184,94 +176,7 @@ PYBIND11_MODULE(TensorFrost, m) { program.Compile(); py::print(program.DebugPrint()); - py::class_(m, "VulkanContext").def(py::init<>()); - - py::class_(m, "Buffer") - .def_property_readonly("size", [](const Buffer& b){ return b.size; }); - - py::class_(m, "ComputeProgram"); - - m.def("createBuffer", &createBuffer, - py::arg("ctx"), py::arg("count"), py::arg("dtypeSize"), py::arg("readOnly"), - py::return_value_policy::move); - - m.def("destroyBuffer", &destroyBuffer, py::arg("ctx"), py::arg("buf")); - - m.def("createComputeProgramFromGLSL", 
&createComputeProgramFromGLSL, - py::arg("ctx"), py::arg("glsl_source"), py::arg("roCount"), py::arg("rwCount"), - py::return_value_policy::move); - - m.def("createComputeProgramFromSlang", &createComputeProgramFromSlang, - py::arg("ctx"), py::arg("moduleName"), py::arg("source"), py::arg("entry"), - py::arg("roCount"), py::arg("rwCount"), - py::return_value_policy::move); - - m.def("destroyComputeProgram", &destroyComputeProgram, py::arg("ctx"), py::arg("prog")); - - m.def("runProgram", &runProgram, - py::arg("ctx"), py::arg("prog"), - py::arg("readonlyBuffers"), py::arg("readwriteBuffers"), - py::arg("numInvocations"), - py::call_guard()); - - // --- numpy I/O --- - m.def("setBufferData", - [](VulkanContext& ctx, Buffer& buf, py::array arr, size_t offset) { - auto info = arr.request(); // GIL held - if (!is_c_contig(info)) throw std::runtime_error("array must be C-contiguous"); - size_t nbytes = static_cast(info.size) * static_cast(info.itemsize); - if (offset + nbytes > buf.size) throw std::out_of_range("write out of range"); - py::gil_scoped_release release; - setBufferData(ctx, buf, info.ptr, nbytes, offset); - }, - py::arg("ctx"), py::arg("buf"), py::arg("array"), py::arg("offset") = 0); - - m.def("getBufferData", - [](VulkanContext& ctx, const Buffer& buf, py::dtype dt, size_t count, size_t offset) { - size_t itemsize = dt.attr("itemsize").cast(); // GIL held - size_t nbytes = count * itemsize; - if (offset + nbytes > buf.size) throw std::out_of_range("read out of range"); - - py::array out(dt, py::array::ShapeContainer{ static_cast(count) }); - auto info = out.request(); // contiguous by default - - { py::gil_scoped_release release; - getBufferData(ctx, buf, info.ptr, nbytes, offset); - } - return out; - }, - py::arg("ctx"), py::arg("buf"), py::arg("dtype"), py::arg("count"), py::arg("offset") = 0); - - m.def("getBufferData_into", - [](VulkanContext& ctx, const Buffer& buf, py::array out, size_t offset) { - auto info = out.request(); // GIL held - if 
(info.readonly) throw std::runtime_error("output array must be writeable"); - if (!is_c_contig(info)) throw std::runtime_error("output array must be C-contiguous"); - size_t nbytes = static_cast(info.size) * static_cast(info.itemsize); - if (offset + nbytes > buf.size) throw std::out_of_range("read out of range"); - py::gil_scoped_release release; - getBufferData(ctx, buf, info.ptr, nbytes, offset); - }, - py::arg("ctx"), py::arg("buf"), py::arg("out"), py::arg("offset") = 0); - - py::class_(m, "WindowContext") - .def_property_readonly("size", - [](const WindowContext& c){ return py::make_tuple(c.extent.width, c.extent.height); }) - .def_property_readonly("format", - [](const WindowContext& c){ return static_cast(c.format); }); - - m.def("createWindow", - static_cast(&createWindow), - py::arg("ctx"), py::arg("width"), py::arg("height"), py::arg("title"), - py::return_value_policy::move, - py::keep_alive<0,1>(), // keep ctx alive as long as WindowContext lives - py::call_guard()); - - m.def("windowOpen", &windowOpen, py::arg("ctx")); - m.def("drawBuffer", - static_cast(&drawBuffer), - py::arg("ctx"), py::arg("buffer"), py::arg("width"), py::arg("height"), py::arg("offset") = 0, - py::call_guard()); + VulkanDefinitions(m); // VulkanContext ctx; // // const size_t N = 1024; diff --git a/TensorFrost/include/Definitions/VulkanBindings.h b/TensorFrost/include/Definitions/VulkanBindings.h new file mode 100644 index 00000000..14159ddc --- /dev/null +++ b/TensorFrost/include/Definitions/VulkanBindings.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace TensorFrost { + +void VulkanDefinitions(pybind11::module_& m); + +} // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/VulkanBindings.cpp b/TensorFrost/src/Definitions/VulkanBindings.cpp new file mode 100644 index 00000000..2ffacd24 --- /dev/null +++ b/TensorFrost/src/Definitions/VulkanBindings.cpp @@ -0,0 +1,94 @@ +#include "Definitions/VulkanBindings.h" +#include "VulkanInterface.h" + +#include +#include 
+ +namespace py = pybind11; + +namespace TensorFrost { + +void VulkanDefinitions(py::module_& m) { + py::class_(m, "Buffer", "Vulkan-backed storage buffer exposed to Python.") + .def(py::init(), + py::arg("count"), py::arg("dtype_size"), py::arg("read_only") = false, + "Create a buffer sized for `count` elements of size `dtype_size`.") + .def_property_readonly("size", &PyBuffer::byteSize, "Total size of the buffer in bytes.") + .def_property_readonly("count", &PyBuffer::elementCapacity, + "Maximum number of elements the buffer can hold for the configured dtype size.") + .def_property_readonly("read_only", &PyBuffer::isReadOnly, + "Whether the buffer is flagged as read-only for compute kernels.") + .def("setData", &PyBuffer::setData, py::arg("data"), py::arg("offset") = 0, + "Upload data from a NumPy array or bytes-like object into the buffer.") + .def("getData", + [](const PyBuffer& self, const py::object& dtype, const py::object& count, size_t offset) { + return self.getData(dtype, count, offset); + }, + py::arg("dtype") = py::none(), py::arg("count") = py::none(), py::arg("offset") = 0, + "Download data from the buffer into a newly allocated NumPy array.") + .def("release", &PyBuffer::release, + "Explicitly destroy the underlying Vulkan buffer and release its memory."); + + m.def("createBuffer", + [](size_t count, size_t dtypeSize, bool readOnly) { + return PyBuffer(count, dtypeSize, readOnly); + }, + py::arg("count"), py::arg("dtype_size"), py::arg("read_only") = false, + py::return_value_policy::move, + "Convenience helper to construct a :class:`Buffer` without calling the class directly."); + + py::class_(m, "ComputeProgram", + "Compiled compute pipeline that can be dispatched on the GPU.") + .def_property_readonly("readonly_count", &PyComputeProgram::readonlyCount, + "Number of read-only storage buffers expected by the program.") + .def_property_readonly("readwrite_count", &PyComputeProgram::readwriteCount, + "Number of read-write storage buffers expected by 
the program.") + .def("run", &PyComputeProgram::run, + py::arg("readonly_buffers"), py::arg("readwrite_buffers"), py::arg("num_invocations"), + "Dispatch the compute pipeline with the provided buffers and invocation count.") + .def("release", &PyComputeProgram::release, + "Explicitly destroy the underlying Vulkan pipeline and associated resources."); + + m.def("createComputeProgramFromGLSL", + [](const std::string& source, uint32_t roCount, uint32_t rwCount) { + return MakeComputeProgramFromGLSL(source, roCount, rwCount); + }, + py::arg("glsl_source"), py::arg("ro_count"), py::arg("rw_count"), + py::return_value_policy::move, + "Compile a compute shader written in GLSL to SPIR-V and wrap it in a :class:`ComputeProgram`."); + + m.def("createComputeProgramFromSlang", + [](const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount) { + return MakeComputeProgramFromSlang(moduleName, source, entry, roCount, rwCount); + }, + py::arg("module_name"), py::arg("source"), py::arg("entry"), + py::arg("ro_count"), py::arg("rw_count"), + py::return_value_policy::move, + "Compile a Slang module to SPIR-V and wrap it in a :class:`ComputeProgram`."); + + py::class_(m, "Window", + "GLFW-backed Vulkan swapchain window for presenting compute output.") + .def(py::init(), py::arg("width"), py::arg("height"), py::arg("title"), + "Create a window with an attached Vulkan swapchain.") + .def_property_readonly("size", &PyWindow::size, + "Current window extent as a tuple ``(width, height)``.") + .def_property_readonly("format", &PyWindow::format, + "Pixel format of the swapchain image as a Vulkan enum value.") + .def("isOpen", &PyWindow::isOpen, + "Return ``True`` while the window is alive and the user has not closed it.") + .def("drawBuffer", &PyWindow::drawBuffer, + py::arg("buffer"), py::arg("width"), py::arg("height"), py::arg("offset") = 0, + "Copy a buffer of packed pixels onto the swapchain.") + .def("close", &PyWindow::close, 
+ "Destroy the window and release its swapchain resources."); + + m.def("createWindow", + [](int width, int height, const std::string& title) { + return PyWindow(width, height, title); + }, + py::arg("width"), py::arg("height"), py::arg("title"), + py::return_value_policy::move, + "Convenience helper to construct a :class:`Window` without calling the class directly."); +} + +} // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp new file mode 100644 index 00000000..591b8dc6 --- /dev/null +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -0,0 +1,301 @@ +#include "VulkanInterface.h" + +#include +#include + +#include +#include +#include +#include + +#include "Backend/Vulkan.h" +#include "Backend/Window.h" + +namespace py = pybind11; + +namespace TensorFrost { +namespace { + +VulkanContext& getContext() { + static VulkanContext ctx{}; + return ctx; +} + +bool isCContiguous(const py::buffer_info& info) { + py::ssize_t stride = info.itemsize; + for (py::ssize_t d = info.ndim - 1; d >= 0; --d) { + if (info.strides[d] != stride) return false; + stride *= info.shape[d]; + } + return true; +} + +} // namespace + +PyBuffer::PyBuffer(size_t count, size_t dtypeSize, bool readOnly) + : ctx_(&getContext()), + buffer_(createBuffer(*ctx_, count, dtypeSize, readOnly)), + readOnly_(readOnly), + dtypeSizeHint_(dtypeSize ? dtypeSize : 1), + lastCount_(count), + lastDtype_(py::none()) {} + +PyBuffer::~PyBuffer() { release(); } + +PyBuffer::PyBuffer(PyBuffer&& other) noexcept { moveFrom(std::move(other)); } + +PyBuffer& PyBuffer::operator=(PyBuffer&& other) noexcept { + if (this != &other) { + release(); + moveFrom(std::move(other)); + } + return *this; +} + +bool PyBuffer::valid() const { return ctx_ && buffer_.buffer; } + +size_t PyBuffer::byteSize() const { return buffer_.size; } + +size_t PyBuffer::elementCapacity() const { return dtypeSizeHint_ ? 
buffer_.size / dtypeSizeHint_ : buffer_.size; } + +bool PyBuffer::isReadOnly() const { return readOnly_; } + +void PyBuffer::release() { + if (ctx_ && buffer_.buffer) { + destroyBuffer(*ctx_, buffer_); + } + buffer_ = {}; + ctx_ = nullptr; + lastDtype_ = py::none(); + lastCount_ = 0; + dtypeSizeHint_ = 0; +} + +void PyBuffer::setData(const py::array& array, size_t offset) { + ensureValid(); + auto info = array.request(); + if (!isCContiguous(info)) throw std::runtime_error("array must be C-contiguous"); + size_t nbytes = static_cast(info.size) * static_cast(info.itemsize); + if (offset + nbytes > buffer_.size) throw std::out_of_range("write out of range"); + { + py::gil_scoped_release release; + setBufferData(*ctx_, buffer_, info.ptr, nbytes, offset); + } + lastDtype_ = array.dtype(); + lastCount_ = static_cast(info.size); + dtypeSizeHint_ = static_cast(info.itemsize ? info.itemsize : 1); +} + +py::array PyBuffer::getData(const py::object& dtypeArg, const py::object& countArg, size_t offset) const { + ensureValid(); + if (offset > buffer_.size) throw std::out_of_range("offset out of range"); + + py::dtype dtype = resolveDtype(dtypeArg); + size_t itemsize = dtype.attr("itemsize").cast(); + if (itemsize == 0) throw std::runtime_error("dtype itemsize cannot be zero"); + + size_t available = buffer_.size - offset; + size_t count = resolveCount(countArg, itemsize, available); + size_t nbytes = count * itemsize; + if (offset + nbytes > buffer_.size) throw std::out_of_range("read out of range"); + + py::array out(dtype, py::array::ShapeContainer{ static_cast(count) }); + auto info = out.request(); + { + py::gil_scoped_release release; + getBufferData(*ctx_, buffer_, info.ptr, nbytes, offset); + } + return out; +} + +Buffer& PyBuffer::raw() { + ensureValid(); + return buffer_; +} + +const Buffer& PyBuffer::raw() const { + ensureValid(); + return buffer_; +} + +void PyBuffer::ensureValid() const { + if (!valid()) throw std::runtime_error("Buffer has been released"); +} + 
+void PyBuffer::moveFrom(PyBuffer&& other) { + ctx_ = other.ctx_; + buffer_ = other.buffer_; + readOnly_ = other.readOnly_; + dtypeSizeHint_ = other.dtypeSizeHint_; + lastCount_ = other.lastCount_; + lastDtype_ = std::move(other.lastDtype_); + other.ctx_ = nullptr; + other.buffer_ = {}; + other.lastDtype_ = py::none(); +} + +py::dtype PyBuffer::resolveDtype(const py::object& dtypeArg) const { + if (!dtypeArg.is_none()) { + return py::reinterpret_borrow(dtypeArg); + } + if (!lastDtype_.is_none()) { + return py::reinterpret_borrow(lastDtype_); + } + switch (dtypeSizeHint_) { + case 2: return py::dtype::of(); + case 4: return py::dtype::of(); + case 8: return py::dtype::of(); + default: return py::dtype::of(); + } +} + +size_t PyBuffer::resolveCount(const py::object& countArg, size_t itemsize, size_t available) const { + if (!countArg.is_none()) { + return countArg.cast(); + } + if (!lastDtype_.is_none() && itemsize == dtypeSizeHint_ && lastCount_ != 0) { + return std::min(lastCount_, available / itemsize); + } + return available / itemsize; +} + +PyComputeProgram::PyComputeProgram(ComputeProgram&& prog) + : ctx_(&getContext()), program_(std::move(prog)) {} + +PyComputeProgram::~PyComputeProgram() { release(); } + +PyComputeProgram::PyComputeProgram(PyComputeProgram&& other) noexcept { moveFrom(std::move(other)); } + +PyComputeProgram& PyComputeProgram::operator=(PyComputeProgram&& other) noexcept { + if (this != &other) { + release(); + moveFrom(std::move(other)); + } + return *this; +} + +void PyComputeProgram::run(const py::iterable& readonlyBuffers, + const py::iterable& readwriteBuffers, + uint32_t numInvocations) { + ensureValid(); + std::vector ro; + std::vector rw; + collectBuffers(readonlyBuffers, ro, "readonly"); + collectBuffers(readwriteBuffers, rw, "readwrite"); + if (ro.size() != program_.numRO || rw.size() != program_.numRW) { + throw std::runtime_error("buffer count does not match program layout"); + } + py::gil_scoped_release release; + 
runProgram(*ctx_, program_, ro, rw, numInvocations); +} + +void PyComputeProgram::release() { + if (ctx_ && program_.pipeline) { + destroyComputeProgram(*ctx_, program_); + } + program_ = {}; + ctx_ = nullptr; +} + +uint32_t PyComputeProgram::readonlyCount() const { return program_.numRO; } + +uint32_t PyComputeProgram::readwriteCount() const { return program_.numRW; } + +void PyComputeProgram::ensureValid() const { + if (!ctx_ || !program_.pipeline) { + throw std::runtime_error("ComputeProgram has been released"); + } +} + +void PyComputeProgram::collectBuffers(const py::iterable& items, + std::vector& out, + const char* label) { + out.clear(); + for (auto obj : items) { + try { + py::handle handle(obj); + auto* buf = handle.cast(); + if (!buf) { + throw py::cast_error("null buffer pointer"); + } + out.push_back(&buf->raw()); + } catch (const py::cast_error&) { + throw std::runtime_error(std::string("expected Buffer in ") + label + " list"); + } + } +} + +void PyComputeProgram::moveFrom(PyComputeProgram&& other) { + ctx_ = other.ctx_; + program_ = other.program_; + other.ctx_ = nullptr; + other.program_ = {}; +} + +PyWindow::PyWindow(int width, int height, const std::string& title) + : ctx_(&getContext()), window_(createWindow(*ctx_, width, height, title.c_str())) {} + +PyWindow::~PyWindow() = default; + +PyWindow::PyWindow(PyWindow&& other) noexcept { moveFrom(std::move(other)); } + +PyWindow& PyWindow::operator=(PyWindow&& other) noexcept { + if (this != &other) { + moveFrom(std::move(other)); + } + return *this; +} + +bool PyWindow::isOpen() const { + ensureValid(); + return windowOpen(window_); +} + +void PyWindow::drawBuffer(const PyBuffer& buffer, uint32_t width, uint32_t height, size_t offset) { + ensureValid(); + py::gil_scoped_release release; + ::drawBuffer(window_, buffer.raw(), width, height, offset); +} + +py::tuple PyWindow::size() const { + ensureValid(); + return py::make_tuple(window_.extent.width, window_.extent.height); +} + +int 
PyWindow::format() const { + ensureValid(); + return static_cast(window_.format); +} + +void PyWindow::close() { + window_ = {}; + ctx_ = nullptr; +} + +void PyWindow::ensureValid() const { + if (!window_.wnd) { + throw std::runtime_error("Window has been closed"); + } +} + +void PyWindow::moveFrom(PyWindow&& other) { + ctx_ = other.ctx_; + window_ = std::move(other.window_); + other.ctx_ = nullptr; +} + +PyComputeProgram MakeComputeProgramFromGLSL(const std::string& source, + uint32_t roCount, + uint32_t rwCount) { + return PyComputeProgram(createComputeProgramFromGLSL(getContext(), source, roCount, rwCount)); +} + +PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, + const std::string& source, + const std::string& entry, + uint32_t roCount, + uint32_t rwCount) { + return PyComputeProgram(createComputeProgramFromSlang(getContext(), moduleName, source, entry, roCount, rwCount)); +} + +} // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h new file mode 100644 index 00000000..b172f25c --- /dev/null +++ b/TensorFrost/src/Definitions/VulkanInterface.h @@ -0,0 +1,123 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include "Backend/Vulkan.h" +#include "Backend/Window.h" + +namespace TensorFrost { + +class PyBuffer { +public: + PyBuffer(size_t count, size_t dtypeSize, bool readOnly); + ~PyBuffer(); + + PyBuffer(const PyBuffer&) = delete; + PyBuffer& operator=(const PyBuffer&) = delete; + + PyBuffer(PyBuffer&& other) noexcept; + PyBuffer& operator=(PyBuffer&& other) noexcept; + + bool valid() const; + size_t byteSize() const; + size_t elementCapacity() const; + bool isReadOnly() const; + + void release(); + + void setData(const pybind11::array& array, size_t offset); + pybind11::array getData(const pybind11::object& dtypeArg, + const pybind11::object& countArg, + size_t offset) const; + + Buffer& raw(); + const Buffer& 
raw() const; + +private: + void ensureValid() const; + void moveFrom(PyBuffer&& other); + pybind11::dtype resolveDtype(const pybind11::object& dtypeArg) const; + size_t resolveCount(const pybind11::object& countArg, size_t itemsize, size_t available) const; + + VulkanContext* ctx_{}; + Buffer buffer_{}; + bool readOnly_{}; + size_t dtypeSizeHint_{}; + size_t lastCount_{}; + pybind11::object lastDtype_; +}; + +class PyComputeProgram { +public: + explicit PyComputeProgram(ComputeProgram&& prog); + ~PyComputeProgram(); + + PyComputeProgram(const PyComputeProgram&) = delete; + PyComputeProgram& operator=(const PyComputeProgram&) = delete; + + PyComputeProgram(PyComputeProgram&& other) noexcept; + PyComputeProgram& operator=(PyComputeProgram&& other) noexcept; + + void run(const pybind11::iterable& readonlyBuffers, + const pybind11::iterable& readwriteBuffers, + uint32_t numInvocations); + + void release(); + + uint32_t readonlyCount() const; + uint32_t readwriteCount() const; + +private: + void ensureValid() const; + static void collectBuffers(const pybind11::iterable& items, + std::vector& out, + const char* label); + void moveFrom(PyComputeProgram&& other); + + VulkanContext* ctx_{}; + ComputeProgram program_{}; +}; + +class PyWindow { +public: + PyWindow(int width, int height, const std::string& title); + ~PyWindow(); + + PyWindow(const PyWindow&) = delete; + PyWindow& operator=(const PyWindow&) = delete; + + PyWindow(PyWindow&& other) noexcept; + PyWindow& operator=(PyWindow&& other) noexcept; + + bool isOpen() const; + void drawBuffer(const PyBuffer& buffer, uint32_t width, uint32_t height, size_t offset); + pybind11::tuple size() const; + int format() const; + void close(); + +private: + void ensureValid() const; + void moveFrom(PyWindow&& other); + + VulkanContext* ctx_{}; + WindowContext window_{}; +}; + +PyComputeProgram MakeComputeProgramFromGLSL(const std::string& source, + uint32_t roCount, + uint32_t rwCount); + +PyComputeProgram 
MakeComputeProgramFromSlang(const std::string& moduleName, + const std::string& source, + const std::string& entry, + uint32_t roCount, + uint32_t rwCount); + +} // namespace TensorFrost diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py index 66d236c0..2da3e2d0 100644 --- a/examples/Slang/mandelbrot.py +++ b/examples/Slang/mandelbrot.py @@ -2,20 +2,18 @@ import TensorFrost as tf from pathlib import Path -ctx = tf.VulkanContext() - W, H = 1024, 768 -win = tf.createWindow(ctx, W, H, "Mandelbrot") +win = tf.createWindow(W, H, "Mandelbrot") fmt = int(win.format) is_bgra = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB -pix = tf.createBuffer(ctx, W*H, 4, False) # uint32 pixels -params = tf.createBuffer(ctx, 8, 4, True) # 8 float32 params +pix = tf.createBuffer(W*H, 4, False) # uint32 pixels +params = tf.createBuffer(8, 4, True) # 8 float32 params with open(Path(__file__).with_name('mandelbrot.slang'), 'r') as f: hlsl = f.read() -prog = tf.createComputeProgramFromSlang(ctx, "mandelbrot", hlsl, "csMain", roCount=1, rwCount=1) +prog = tf.createComputeProgramFromSlang("mandelbrot", hlsl, "csMain", ro_count=1, rw_count=1) # view rectangle with aspect correction xspan = 3.0 @@ -25,13 +23,11 @@ max_iter = 500.0 p = np.array([float(W), float(H), xmin, ymin, dx, dy, max_iter, 1.0 if is_bgra else 0.0], dtype=np.float32) -tf.setBufferData(ctx, params, p) +params.setData(p) try: - while tf.windowOpen(win): - tf.runProgram(ctx, prog, [params], [pix], W*H) - tf.drawBuffer(win, pix, W, H) + while win.isOpen(): + prog.run([params], [pix], W*H) + win.drawBuffer(pix, W, H) finally: - tf.destroyComputeProgram(ctx, prog) - tf.destroyBuffer(ctx, pix) - tf.destroyBuffer(ctx, params) + win.close() diff --git a/examples/debug.py b/examples/debug.py index fac28890..0ad93ce9 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -51,17 +51,15 @@ """ def main(): - ctx = tf.VulkanContext() - W, H = 1024, 768 - win = tf.createWindow(ctx, W, H, "Mandelbrot (compute → 
buffer → swapchain)") + win = tf.createWindow(W, H, "Mandelbrot (compute → buffer → swapchain)") fmt = int(win.format) is_bgra = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB - pix = tf.createBuffer(ctx, W*H, 4, False) # uint32 pixels - params = tf.createBuffer(ctx, 8, 4, True) # 8 float32 params + pix = tf.createBuffer(W*H, 4, False) # uint32 pixels + params = tf.createBuffer(8, 4, True) # 8 float32 params - prog = tf.createComputeProgramFromGLSL(ctx, glsl, roCount=1, rwCount=1) + prog = tf.createComputeProgramFromGLSL(glsl, ro_count=1, rw_count=1) # view rectangle with aspect correction xspan = 3.0 @@ -71,16 +69,14 @@ def main(): max_iter = 500.0 p = np.array([float(W), float(H), xmin, ymin, dx, dy, max_iter, 1.0 if is_bgra else 0.0], dtype=np.float32) - tf.setBufferData(ctx, params, p) + params.setData(p) try: - while tf.windowOpen(win): - tf.runProgram(ctx, prog, [params], [pix], W*H) - tf.drawBuffer(win, pix, W, H) + while win.isOpen(): + prog.run([params], [pix], W*H) + win.drawBuffer(pix, W, H) finally: - tf.destroyComputeProgram(ctx, prog) - tf.destroyBuffer(ctx, pix) - tf.destroyBuffer(ctx, params) + win.close() if __name__ == "__main__": main() \ No newline at end of file diff --git a/setup_python_env.cmd b/setup_python_env.cmd new file mode 100644 index 00000000..eaf88469 --- /dev/null +++ b/setup_python_env.cmd @@ -0,0 +1,65 @@ +@echo off +setlocal enableextensions + +set "REQUESTED_VERSION=%~1" +if defined REQUESTED_VERSION ( + set "PYTHON_VERSION=%REQUESTED_VERSION%" + set "VERSION_WAS_EXPLICIT=1" +) else ( + set "PYTHON_VERSION=3.12" + set "VERSION_WAS_EXPLICIT=0" +) + +set "SCRIPT_DIR=%~dp0" +if not defined SCRIPT_DIR set "SCRIPT_DIR=.\" +for %%I in ("%SCRIPT_DIR%.") do set "REPO_ROOT=%%~fI" + +set "VENV_DIR=%REPO_ROOT%\.venv" +set "VENV_PYTHON=%VENV_DIR%\Scripts\python.exe" + +set "CREATE_VENV_CMD=py -%PYTHON_VERSION% -m venv" +py -%PYTHON_VERSION% -c "import sys" >nul 2>&1 +if errorlevel 1 ( + if "%VERSION_WAS_EXPLICIT%"=="1" ( + echo 
[TensorFrost] Python %PYTHON_VERSION% is not available via the launcher. Install it or pass a different version. + exit /b 1 + ) else ( + echo [TensorFrost] Python %PYTHON_VERSION% not found; falling back to default interpreter for venv creation. + set "CREATE_VENV_CMD=py -m venv" + ) +) + +if exist "%VENV_PYTHON%" ( + echo [TensorFrost] Using existing virtual environment at "%VENV_DIR%" +) else ( + echo [TensorFrost] Creating virtual environment at "%VENV_DIR%" + %CREATE_VENV_CMD% "%VENV_DIR%" + if errorlevel 1 ( + echo [TensorFrost] Failed to create virtual environment. + exit /b 1 + ) +) + +if not exist "%VENV_PYTHON%" ( + echo [TensorFrost] Could not find python interpreter in "%VENV_DIR%". + exit /b 1 +) + +echo [TensorFrost] Upgrading pip inside the virtual environment... +"%VENV_PYTHON%" -m pip install --upgrade pip +if errorlevel 1 ( + echo [TensorFrost] Failed to upgrade pip. + exit /b 1 +) + +echo [TensorFrost] Installing TensorFrost in editable mode (verbose)... +"%VENV_PYTHON%" -m pip install -v -e "%REPO_ROOT%\Python" +if errorlevel 1 ( + echo [TensorFrost] Editable install failed. + exit /b 1 +) + +echo [TensorFrost] TensorFrost development environment ready. +echo [TensorFrost] Activate it with "%VENV_DIR%\Scripts\activate.bat" or "pwsh %VENV_DIR%\Scripts\Activate.ps1". 
+ +exit /b 0 diff --git a/tests/vulkan_window_test.py b/tests/vulkan_window_test.py new file mode 100644 index 00000000..74fc6672 --- /dev/null +++ b/tests/vulkan_window_test.py @@ -0,0 +1,54 @@ +import unittest +from contextlib import ExitStack + +import numpy as np + +import TensorFrost as tf + + +_SIMPLE_GLSL = """#version 450 +layout(local_size_x = 64) in; +layout(set = 0, binding = 0) buffer Pixels { uint data[]; }; + +void main() { + uint idx = gl_GlobalInvocationID.x; + if (idx >= data.length()) return; + data[idx] = 0xff3366ff; +} +""" + + +class VulkanWindowTest(unittest.TestCase): + def test_compute_dispatch_and_window_present(self): + width = height = 8 + invocation_count = width * height + + try: + pixel_buffer = tf.createBuffer(invocation_count, 4, False) + except RuntimeError as exc: # pragma: no cover - Vulkan not available + self.skipTest(f"Vulkan buffer creation failed: {exc}") + + with ExitStack() as resources: + resources.callback(pixel_buffer.release) + + program = tf.createComputeProgramFromGLSL(_SIMPLE_GLSL, ro_count=0, rw_count=1) + resources.callback(program.release) + + program.run([], [pixel_buffer], invocation_count) + pixels = pixel_buffer.getData(np.dtype(np.uint32), invocation_count) + self.assertTrue(np.all(pixels == 0xFF3366FF), "Compute shader did not write expected color") + + try: + window = tf.createWindow(width, height, "TensorFrost Vulkan Test") + except RuntimeError as exc: # pragma: no cover - Vulkan window not available + self.skipTest(f"Vulkan window creation failed: {exc}") + + resources.callback(window.close) + + # Present once to ensure the binding path is exercised. 
+ window.drawBuffer(pixel_buffer, width, height) + self.assertTrue(window.isOpen(), "Window should report as open after initial present") + + +if __name__ == "__main__": + unittest.main() From 63cd0caa957bfff4c23e5308aa871b2e80a4d378 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Wed, 5 Nov 2025 23:42:56 +0100 Subject: [PATCH 22/44] Use global vulkan context on backend --- TensorFrost/include/Backend/Vulkan.h | 18 ++++---- TensorFrost/include/Backend/Window.h | 2 +- TensorFrost/src/Backend/Vulkan.cpp | 41 +++++++++++++------ TensorFrost/src/Backend/Window.cpp | 3 +- .../src/Definitions/VulkanInterface.cpp | 35 +++++++--------- 5 files changed, 57 insertions(+), 42 deletions(-) diff --git a/TensorFrost/include/Backend/Vulkan.h b/TensorFrost/include/Backend/Vulkan.h index 8c222b0d..001c30f7 100644 --- a/TensorFrost/include/Backend/Vulkan.h +++ b/TensorFrost/include/Backend/Vulkan.h @@ -83,14 +83,16 @@ struct VulkanContext { ~VulkanContext(); }; -Buffer createBuffer(VulkanContext& ctx, size_t count, size_t dtypeSize, bool readOnly); -void destroyBuffer(VulkanContext& ctx, Buffer& buf); -void setBufferData(VulkanContext& ctx, Buffer& buf, const void* src, size_t bytes, size_t offset = 0); -void getBufferData(VulkanContext& ctx, const Buffer& buf, void* dst, size_t bytes, size_t offset = 0); +Buffer createBuffer(size_t count, size_t dtypeSize, bool readOnly); +void destroyBuffer(Buffer& buf); +void setBufferData(Buffer& buf, const void* src, size_t bytes, size_t offset = 0); +void getBufferData(const Buffer& buf, void* dst, size_t bytes, size_t offset = 0); -ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, const std::string& glsl, uint32_t roCount, uint32_t rwCount); -ComputeProgram createComputeProgramFromSlang(VulkanContext& ctx, const std::string& moduleName, +ComputeProgram createComputeProgramFromGLSL(const std::string& glsl, uint32_t roCount, uint32_t rwCount); +ComputeProgram 
createComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount); -void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog); +void destroyComputeProgram(ComputeProgram& prog); -void runProgram(VulkanContext& ctx, const ComputeProgram& prog, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers, uint32_t n); \ No newline at end of file +void runProgram(const ComputeProgram& prog, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers, uint32_t n); + +VulkanContext& getVulkanContext(); \ No newline at end of file diff --git a/TensorFrost/include/Backend/Window.h b/TensorFrost/include/Backend/Window.h index 22d26a4e..e18a5edd 100644 --- a/TensorFrost/include/Backend/Window.h +++ b/TensorFrost/include/Backend/Window.h @@ -73,7 +73,7 @@ struct WindowContext { } }; -WindowContext createWindow(VulkanContext& vctx, int width, int height, const char* title); +WindowContext createWindow(int width, int height, const char* title); bool windowOpen(const WindowContext& ctx); void drawBuffer(WindowContext& ctx, vk::Buffer src, uint32_t width, uint32_t height, vk::DeviceSize offset = 0); void drawBuffer(WindowContext& ctx, const Buffer& b, uint32_t w, uint32_t h, size_t offset = 0); diff --git a/TensorFrost/src/Backend/Vulkan.cpp b/TensorFrost/src/Backend/Vulkan.cpp index 0ab38cbb..be5a3c8c 100644 --- a/TensorFrost/src/Backend/Vulkan.cpp +++ b/TensorFrost/src/Backend/Vulkan.cpp @@ -5,6 +5,17 @@ VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE #include #include +namespace { +VulkanContext& getOrCreateGlobalContext() { + static VulkanContext ctx{}; + return ctx; +} +} + +VulkanContext& getVulkanContext() { + return getOrCreateGlobalContext(); +} + VulkanContext::VulkanContext() { if (!glfwInit()) throw std::runtime_error("GLFW init failed"); if (!glfwVulkanSupported()) @@ -234,7 +245,8 @@ static bool takeBufferFromCache(VulkanContext& ctx, size_t 
bytes, Buffer& out) { } // create a storage buffer -Buffer createBuffer(VulkanContext& ctx, size_t count, size_t dtypeSize, bool readOnly) { +Buffer createBuffer(size_t count, size_t dtypeSize, bool readOnly) { + auto& ctx = getVulkanContext(); Buffer buf{}; buf.size = count * dtypeSize; @@ -260,7 +272,8 @@ Buffer createBuffer(VulkanContext& ctx, size_t count, size_t dtypeSize, bool rea return buf; } -void destroyBuffer(VulkanContext& ctx, Buffer& buf) { +void destroyBuffer(Buffer& buf) { + auto& ctx = getVulkanContext(); if (!buf.buffer) return; // logical death → drop any dsCache entries that reference it invalidateDescriptorCacheForBuffer(ctx, buf.buffer); // keep if you have the DS cache; else remove this line @@ -276,7 +289,8 @@ void destroyBuffer(VulkanContext& ctx, Buffer& buf) { evictBuffersToCapacity(ctx); } -void setBufferData(VulkanContext& ctx, Buffer& buf, const void* src, size_t bytes, size_t offset) { +void setBufferData(Buffer& buf, const void* src, size_t bytes, size_t offset) { + auto& ctx = getVulkanContext(); if (offset + bytes > buf.size) throw std::out_of_range("write out of range"); auto atom = ctx.physicalDevice.getProperties().limits.nonCoherentAtomSize; vk::DeviceSize mapOff = offset - (offset % atom); @@ -290,7 +304,8 @@ void setBufferData(VulkanContext& ctx, Buffer& buf, const void* src, size_t byte ctx.device.unmapMemory(buf.memory); } -void getBufferData(VulkanContext& ctx, const Buffer& buf, void* dst, size_t bytes, size_t offset) { +void getBufferData(const Buffer& buf, void* dst, size_t bytes, size_t offset) { + auto& ctx = getVulkanContext(); if (offset + bytes > buf.size) throw std::out_of_range("read out of range"); auto atom = ctx.physicalDevice.getProperties().limits.nonCoherentAtomSize; vk::DeviceSize mapOff = offset - (offset % atom); @@ -413,9 +428,9 @@ ComputeBindings createBindings(VulkanContext& ctx, const ComputeProgram& prog, return b; } -static ComputeProgram createComputeProgram(VulkanContext& ctx, - const 
std::vector& spirv, +static ComputeProgram createComputeProgram(const std::vector& spirv, uint32_t roCount, uint32_t rwCount) { + auto& ctx = getVulkanContext(); ComputeProgram prog; prog.numRO = roCount; prog.numRW = rwCount; @@ -440,17 +455,18 @@ static ComputeProgram createComputeProgram(VulkanContext& ctx, return prog; } -ComputeProgram createComputeProgramFromGLSL(VulkanContext& ctx, const std::string& glsl, uint32_t roCount, uint32_t rwCount) { +ComputeProgram createComputeProgramFromGLSL(const std::string& glsl, uint32_t roCount, uint32_t rwCount) { auto spirv = compileGLSLToSpirv(glsl); - return createComputeProgram(ctx, spirv, roCount, rwCount); + return createComputeProgram(spirv, roCount, rwCount); } -ComputeProgram createComputeProgramFromSlang(VulkanContext& ctx, const std::string& moduleName, +ComputeProgram createComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount) { auto spirv = compileSlangToSpirv(moduleName.c_str(), source.c_str(), entry.c_str(), "spirv_1_5"); - return createComputeProgram(ctx, spirv, roCount, rwCount); + return createComputeProgram(spirv, roCount, rwCount); } -void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog) { +void destroyComputeProgram(ComputeProgram& prog) { + auto& ctx = getVulkanContext(); invalidateDescriptorCacheForLayout(ctx, prog.descriptorLayout); ctx.device.destroyPipeline(prog.pipeline); ctx.device.destroyPipelineLayout(prog.pipelineLayout); @@ -459,10 +475,11 @@ void destroyComputeProgram(VulkanContext& ctx, ComputeProgram& prog) { prog = {}; } -void runProgram(VulkanContext& ctx, const ComputeProgram& prog, +void runProgram(const ComputeProgram& prog, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers, uint32_t n) { + auto& ctx = getVulkanContext(); auto set = getOrCreateSet(ctx, prog, readonlyBuffers, readwriteBuffers); vk::CommandBufferAllocateInfo ai(ctx.commandPool, 
vk::CommandBufferLevel::ePrimary, 1); diff --git a/TensorFrost/src/Backend/Window.cpp b/TensorFrost/src/Backend/Window.cpp index 2fc4ab2e..4c06a63d 100644 --- a/TensorFrost/src/Backend/Window.cpp +++ b/TensorFrost/src/Backend/Window.cpp @@ -1,7 +1,8 @@ #include "Backend/Vulkan.h" #include "Backend/Window.h" -WindowContext createWindow(VulkanContext& vctx, int width, int height, const char* title) { +WindowContext createWindow(int width, int height, const char* title) { + auto& vctx = getVulkanContext(); if (!glfwInit()) throw std::runtime_error("glfwInit"); glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API); diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp index 591b8dc6..7cfc8533 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.cpp +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -16,11 +16,6 @@ namespace py = pybind11; namespace TensorFrost { namespace { -VulkanContext& getContext() { - static VulkanContext ctx{}; - return ctx; -} - bool isCContiguous(const py::buffer_info& info) { py::ssize_t stride = info.itemsize; for (py::ssize_t d = info.ndim - 1; d >= 0; --d) { @@ -33,12 +28,12 @@ bool isCContiguous(const py::buffer_info& info) { } // namespace PyBuffer::PyBuffer(size_t count, size_t dtypeSize, bool readOnly) - : ctx_(&getContext()), - buffer_(createBuffer(*ctx_, count, dtypeSize, readOnly)), - readOnly_(readOnly), - dtypeSizeHint_(dtypeSize ? dtypeSize : 1), - lastCount_(count), - lastDtype_(py::none()) {} + : ctx_(&getVulkanContext()), + buffer_(createBuffer(count, dtypeSize, readOnly)), + readOnly_(readOnly), + dtypeSizeHint_(dtypeSize ? 
dtypeSize : 1), + lastCount_(count), + lastDtype_(py::none()) {} PyBuffer::~PyBuffer() { release(); } @@ -62,7 +57,7 @@ bool PyBuffer::isReadOnly() const { return readOnly_; } void PyBuffer::release() { if (ctx_ && buffer_.buffer) { - destroyBuffer(*ctx_, buffer_); + destroyBuffer(buffer_); } buffer_ = {}; ctx_ = nullptr; @@ -79,7 +74,7 @@ void PyBuffer::setData(const py::array& array, size_t offset) { if (offset + nbytes > buffer_.size) throw std::out_of_range("write out of range"); { py::gil_scoped_release release; - setBufferData(*ctx_, buffer_, info.ptr, nbytes, offset); + setBufferData(buffer_, info.ptr, nbytes, offset); } lastDtype_ = array.dtype(); lastCount_ = static_cast(info.size); @@ -103,7 +98,7 @@ py::array PyBuffer::getData(const py::object& dtypeArg, const py::object& countA auto info = out.request(); { py::gil_scoped_release release; - getBufferData(*ctx_, buffer_, info.ptr, nbytes, offset); + getBufferData(buffer_, info.ptr, nbytes, offset); } return out; } @@ -160,7 +155,7 @@ size_t PyBuffer::resolveCount(const py::object& countArg, size_t itemsize, size_ } PyComputeProgram::PyComputeProgram(ComputeProgram&& prog) - : ctx_(&getContext()), program_(std::move(prog)) {} + : ctx_(&getVulkanContext()), program_(std::move(prog)) {} PyComputeProgram::~PyComputeProgram() { release(); } @@ -186,12 +181,12 @@ void PyComputeProgram::run(const py::iterable& readonlyBuffers, throw std::runtime_error("buffer count does not match program layout"); } py::gil_scoped_release release; - runProgram(*ctx_, program_, ro, rw, numInvocations); + runProgram(program_, ro, rw, numInvocations); } void PyComputeProgram::release() { if (ctx_ && program_.pipeline) { - destroyComputeProgram(*ctx_, program_); + destroyComputeProgram(program_); } program_ = {}; ctx_ = nullptr; @@ -233,7 +228,7 @@ void PyComputeProgram::moveFrom(PyComputeProgram&& other) { } PyWindow::PyWindow(int width, int height, const std::string& title) - : ctx_(&getContext()), window_(createWindow(*ctx_, 
width, height, title.c_str())) {} + : ctx_(&getVulkanContext()), window_(createWindow(width, height, title.c_str())) {} PyWindow::~PyWindow() = default; @@ -287,7 +282,7 @@ void PyWindow::moveFrom(PyWindow&& other) { PyComputeProgram MakeComputeProgramFromGLSL(const std::string& source, uint32_t roCount, uint32_t rwCount) { - return PyComputeProgram(createComputeProgramFromGLSL(getContext(), source, roCount, rwCount)); + return PyComputeProgram(createComputeProgramFromGLSL(source, roCount, rwCount)); } PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, @@ -295,7 +290,7 @@ PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, const std::string& entry, uint32_t roCount, uint32_t rwCount) { - return PyComputeProgram(createComputeProgramFromSlang(getContext(), moduleName, source, entry, roCount, rwCount)); + return PyComputeProgram(createComputeProgramFromSlang(moduleName, source, entry, roCount, rwCount)); } } // namespace TensorFrost From c81072c42601a378b037d2729173581004c9df17 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Thu, 6 Nov 2025 01:08:27 +0100 Subject: [PATCH 23/44] Refactor project structure --- TensorFrost/Backend/CMakeLists.txt | 40 +++++++++++++++++++ .../{ => Backend}/include/Backend/Vulkan.h | 0 .../{ => Backend}/include/Backend/Window.h | 0 .../{src/Backend => Backend/src}/Vulkan.cpp | 0 .../{src/Backend => Backend/src}/Window.cpp | 0 TensorFrost/CMakeLists.txt | 25 ++++++------ TensorFrost/Compiler/CMakeLists.txt | 29 ++++++++++++++ .../{ => Compiler}/include/Compiler/Common.h | 0 .../include/Compiler/ExecutionContext.h | 0 .../include/Compiler/Operation.h | 0 .../include/Compiler/OperationArguments.h | 0 .../include/Compiler/OperationBlocks.h | 0 .../include/Compiler/OperationRegistry.h | 0 .../include/Compiler/Overloads.h | 0 .../{ => Compiler}/include/Compiler/Printer.h | 0 .../include/Compiler/TFProgram.h | 0 .../{ => 
Compiler}/include/Compiler/Value.h | 0 .../{src/Compiler => Compiler/src}/Common.cpp | 0 .../src}/ExecutionContext.cpp | 0 .../Compiler => Compiler/src}/Operation.cpp | 0 .../src}/OperationArguments.cpp | 0 .../src}/OperationBlocks.cpp | 0 .../src}/OperationRegistry.cpp | 0 .../Compiler => Compiler/src}/Overloads.cpp | 0 .../Compiler => Compiler/src}/Printer.cpp | 0 .../Compiler => Compiler/src}/TFProgram.cpp | 0 .../{src/Compiler => Compiler/src}/Value.cpp | 0 27 files changed, 81 insertions(+), 13 deletions(-) create mode 100644 TensorFrost/Backend/CMakeLists.txt rename TensorFrost/{ => Backend}/include/Backend/Vulkan.h (100%) rename TensorFrost/{ => Backend}/include/Backend/Window.h (100%) rename TensorFrost/{src/Backend => Backend/src}/Vulkan.cpp (100%) rename TensorFrost/{src/Backend => Backend/src}/Window.cpp (100%) create mode 100644 TensorFrost/Compiler/CMakeLists.txt rename TensorFrost/{ => Compiler}/include/Compiler/Common.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/ExecutionContext.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/Operation.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/OperationArguments.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/OperationBlocks.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/OperationRegistry.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/Overloads.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/Printer.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/TFProgram.h (100%) rename TensorFrost/{ => Compiler}/include/Compiler/Value.h (100%) rename TensorFrost/{src/Compiler => Compiler/src}/Common.cpp (100%) rename TensorFrost/{src/Compiler => Compiler/src}/ExecutionContext.cpp (100%) rename TensorFrost/{src/Compiler => Compiler/src}/Operation.cpp (100%) rename TensorFrost/{src/Compiler => Compiler/src}/OperationArguments.cpp (100%) rename TensorFrost/{src/Compiler => Compiler/src}/OperationBlocks.cpp (100%) rename 
TensorFrost/{src/Compiler => Compiler/src}/OperationRegistry.cpp (100%) rename TensorFrost/{src/Compiler => Compiler/src}/Overloads.cpp (100%) rename TensorFrost/{src/Compiler => Compiler/src}/Printer.cpp (100%) rename TensorFrost/{src/Compiler => Compiler/src}/TFProgram.cpp (100%) rename TensorFrost/{src/Compiler => Compiler/src}/Value.cpp (100%) diff --git a/TensorFrost/Backend/CMakeLists.txt b/TensorFrost/Backend/CMakeLists.txt new file mode 100644 index 00000000..e4e3648d --- /dev/null +++ b/TensorFrost/Backend/CMakeLists.txt @@ -0,0 +1,40 @@ +set(TF_BACKEND_INC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) +set(TF_BACKEND_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) + +file(GLOB_RECURSE TENSORFROST_BACKEND_SOURCE_LIST CONFIGURE_DEPENDS + ${TF_BACKEND_SRC_DIR}/*.cpp) + +file(GLOB_RECURSE TENSORFROST_BACKEND_HEADER_LIST CONFIGURE_DEPENDS + ${TF_BACKEND_INC_DIR}/*.h + ${TF_BACKEND_INC_DIR}/*.hpp) + +add_library(TensorFrostBackend STATIC + ${TENSORFROST_BACKEND_SOURCE_LIST} + ${TENSORFROST_BACKEND_HEADER_LIST}) + +target_include_directories(TensorFrostBackend + PUBLIC + ${TF_BACKEND_INC_DIR} + $ENV{VULKAN_SDK}/Include) + +target_link_libraries(TensorFrostBackend + PUBLIC + Vulkan::Vulkan + Vulkan::shaderc_combined + glfw + $<$:${SLANG_LIB_DEBUG}> + $<$>:${SLANG_LIB_RELEASE}>) + +target_compile_definitions(TensorFrostBackend PUBLIC VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) + +target_compile_features(TensorFrostBackend PUBLIC cxx_std_20) + +if (MSVC) + target_compile_options(TensorFrostBackend PRIVATE /wd4804 /wd4805 /wd4018) +endif() + +source_group(TREE ${TF_BACKEND_SRC_DIR} PREFIX "Source Files" + FILES ${TENSORFROST_BACKEND_SOURCE_LIST}) + +source_group(TREE ${TF_BACKEND_INC_DIR} PREFIX "Header Files" + FILES ${TENSORFROST_BACKEND_HEADER_LIST}) diff --git a/TensorFrost/include/Backend/Vulkan.h b/TensorFrost/Backend/include/Backend/Vulkan.h similarity index 100% rename from TensorFrost/include/Backend/Vulkan.h rename to TensorFrost/Backend/include/Backend/Vulkan.h diff --git 
a/TensorFrost/include/Backend/Window.h b/TensorFrost/Backend/include/Backend/Window.h similarity index 100% rename from TensorFrost/include/Backend/Window.h rename to TensorFrost/Backend/include/Backend/Window.h diff --git a/TensorFrost/src/Backend/Vulkan.cpp b/TensorFrost/Backend/src/Vulkan.cpp similarity index 100% rename from TensorFrost/src/Backend/Vulkan.cpp rename to TensorFrost/Backend/src/Vulkan.cpp diff --git a/TensorFrost/src/Backend/Window.cpp b/TensorFrost/Backend/src/Window.cpp similarity index 100% rename from TensorFrost/src/Backend/Window.cpp rename to TensorFrost/Backend/src/Window.cpp diff --git a/TensorFrost/CMakeLists.txt b/TensorFrost/CMakeLists.txt index f6743c9d..35024978 100644 --- a/TensorFrost/CMakeLists.txt +++ b/TensorFrost/CMakeLists.txt @@ -1,6 +1,16 @@ set(TF_INC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) set(TF_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) +find_library(SLANG_LIB_RELEASE NAMES slang PATHS $ENV{VULKAN_SDK}/Lib) +find_library(SLANG_LIB_DEBUG NAMES slangd slang PATHS $ENV{VULKAN_SDK}/Lib) + +if (NOT SLANG_LIB_RELEASE) + message(FATAL_ERROR "slang.lib not found in %VULKAN_SDK%/Lib") +endif() + +add_subdirectory(Compiler) +add_subdirectory(Backend) + file(GLOB_RECURSE TENSORFROST_SOURCE_LIST CONFIGURE_DEPENDS ${TF_SRC_DIR}/*.cpp) @@ -32,20 +42,9 @@ if(APPLE) ) endif() -find_library(SLANG_LIB_RELEASE NAMES slang PATHS $ENV{VULKAN_SDK}/Lib) -find_library(SLANG_LIB_DEBUG NAMES slangd slang PATHS $ENV{VULKAN_SDK}/Lib) - -if (NOT SLANG_LIB_RELEASE) - message(FATAL_ERROR "slang.lib not found in %VULKAN_SDK%/Lib") -endif() - -target_compile_definitions(TensorFrost PRIVATE VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) -target_link_libraries(TensorFrost PRIVATE Vulkan::Vulkan Vulkan::shaderc_combined glfw) target_link_libraries(TensorFrost PRIVATE - $<$:${SLANG_LIB_DEBUG}> - $<$>:${SLANG_LIB_RELEASE}> -) -target_include_directories(TensorFrost PRIVATE $ENV{VULKAN_SDK}/Include) + TensorFrostCompiler + TensorFrostBackend) if (MSVC) 
target_compile_options(TensorFrost PRIVATE /wd4804 /wd4805 /wd4018) diff --git a/TensorFrost/Compiler/CMakeLists.txt b/TensorFrost/Compiler/CMakeLists.txt new file mode 100644 index 00000000..be053e5f --- /dev/null +++ b/TensorFrost/Compiler/CMakeLists.txt @@ -0,0 +1,29 @@ +set(TF_COMPILER_INC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) +set(TF_COMPILER_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) + +file(GLOB_RECURSE TENSORFROST_COMPILER_SOURCE_LIST CONFIGURE_DEPENDS + ${TF_COMPILER_SRC_DIR}/*.cpp) + +file(GLOB_RECURSE TENSORFROST_COMPILER_HEADER_LIST CONFIGURE_DEPENDS + ${TF_COMPILER_INC_DIR}/*.h + ${TF_COMPILER_INC_DIR}/*.hpp) + +add_library(TensorFrostCompiler STATIC + ${TENSORFROST_COMPILER_SOURCE_LIST} + ${TENSORFROST_COMPILER_HEADER_LIST}) + +target_include_directories(TensorFrostCompiler + PUBLIC + ${TF_COMPILER_INC_DIR}) + +target_compile_features(TensorFrostCompiler PUBLIC cxx_std_20) + +if (MSVC) + target_compile_options(TensorFrostCompiler PRIVATE /wd4804 /wd4805 /wd4018) +endif() + +source_group(TREE ${TF_COMPILER_SRC_DIR} PREFIX "Source Files" + FILES ${TENSORFROST_COMPILER_SOURCE_LIST}) + +source_group(TREE ${TF_COMPILER_INC_DIR} PREFIX "Header Files" + FILES ${TENSORFROST_COMPILER_HEADER_LIST}) diff --git a/TensorFrost/include/Compiler/Common.h b/TensorFrost/Compiler/include/Compiler/Common.h similarity index 100% rename from TensorFrost/include/Compiler/Common.h rename to TensorFrost/Compiler/include/Compiler/Common.h diff --git a/TensorFrost/include/Compiler/ExecutionContext.h b/TensorFrost/Compiler/include/Compiler/ExecutionContext.h similarity index 100% rename from TensorFrost/include/Compiler/ExecutionContext.h rename to TensorFrost/Compiler/include/Compiler/ExecutionContext.h diff --git a/TensorFrost/include/Compiler/Operation.h b/TensorFrost/Compiler/include/Compiler/Operation.h similarity index 100% rename from TensorFrost/include/Compiler/Operation.h rename to TensorFrost/Compiler/include/Compiler/Operation.h diff --git 
a/TensorFrost/include/Compiler/OperationArguments.h b/TensorFrost/Compiler/include/Compiler/OperationArguments.h similarity index 100% rename from TensorFrost/include/Compiler/OperationArguments.h rename to TensorFrost/Compiler/include/Compiler/OperationArguments.h diff --git a/TensorFrost/include/Compiler/OperationBlocks.h b/TensorFrost/Compiler/include/Compiler/OperationBlocks.h similarity index 100% rename from TensorFrost/include/Compiler/OperationBlocks.h rename to TensorFrost/Compiler/include/Compiler/OperationBlocks.h diff --git a/TensorFrost/include/Compiler/OperationRegistry.h b/TensorFrost/Compiler/include/Compiler/OperationRegistry.h similarity index 100% rename from TensorFrost/include/Compiler/OperationRegistry.h rename to TensorFrost/Compiler/include/Compiler/OperationRegistry.h diff --git a/TensorFrost/include/Compiler/Overloads.h b/TensorFrost/Compiler/include/Compiler/Overloads.h similarity index 100% rename from TensorFrost/include/Compiler/Overloads.h rename to TensorFrost/Compiler/include/Compiler/Overloads.h diff --git a/TensorFrost/include/Compiler/Printer.h b/TensorFrost/Compiler/include/Compiler/Printer.h similarity index 100% rename from TensorFrost/include/Compiler/Printer.h rename to TensorFrost/Compiler/include/Compiler/Printer.h diff --git a/TensorFrost/include/Compiler/TFProgram.h b/TensorFrost/Compiler/include/Compiler/TFProgram.h similarity index 100% rename from TensorFrost/include/Compiler/TFProgram.h rename to TensorFrost/Compiler/include/Compiler/TFProgram.h diff --git a/TensorFrost/include/Compiler/Value.h b/TensorFrost/Compiler/include/Compiler/Value.h similarity index 100% rename from TensorFrost/include/Compiler/Value.h rename to TensorFrost/Compiler/include/Compiler/Value.h diff --git a/TensorFrost/src/Compiler/Common.cpp b/TensorFrost/Compiler/src/Common.cpp similarity index 100% rename from TensorFrost/src/Compiler/Common.cpp rename to TensorFrost/Compiler/src/Common.cpp diff --git 
a/TensorFrost/src/Compiler/ExecutionContext.cpp b/TensorFrost/Compiler/src/ExecutionContext.cpp similarity index 100% rename from TensorFrost/src/Compiler/ExecutionContext.cpp rename to TensorFrost/Compiler/src/ExecutionContext.cpp diff --git a/TensorFrost/src/Compiler/Operation.cpp b/TensorFrost/Compiler/src/Operation.cpp similarity index 100% rename from TensorFrost/src/Compiler/Operation.cpp rename to TensorFrost/Compiler/src/Operation.cpp diff --git a/TensorFrost/src/Compiler/OperationArguments.cpp b/TensorFrost/Compiler/src/OperationArguments.cpp similarity index 100% rename from TensorFrost/src/Compiler/OperationArguments.cpp rename to TensorFrost/Compiler/src/OperationArguments.cpp diff --git a/TensorFrost/src/Compiler/OperationBlocks.cpp b/TensorFrost/Compiler/src/OperationBlocks.cpp similarity index 100% rename from TensorFrost/src/Compiler/OperationBlocks.cpp rename to TensorFrost/Compiler/src/OperationBlocks.cpp diff --git a/TensorFrost/src/Compiler/OperationRegistry.cpp b/TensorFrost/Compiler/src/OperationRegistry.cpp similarity index 100% rename from TensorFrost/src/Compiler/OperationRegistry.cpp rename to TensorFrost/Compiler/src/OperationRegistry.cpp diff --git a/TensorFrost/src/Compiler/Overloads.cpp b/TensorFrost/Compiler/src/Overloads.cpp similarity index 100% rename from TensorFrost/src/Compiler/Overloads.cpp rename to TensorFrost/Compiler/src/Overloads.cpp diff --git a/TensorFrost/src/Compiler/Printer.cpp b/TensorFrost/Compiler/src/Printer.cpp similarity index 100% rename from TensorFrost/src/Compiler/Printer.cpp rename to TensorFrost/Compiler/src/Printer.cpp diff --git a/TensorFrost/src/Compiler/TFProgram.cpp b/TensorFrost/Compiler/src/TFProgram.cpp similarity index 100% rename from TensorFrost/src/Compiler/TFProgram.cpp rename to TensorFrost/Compiler/src/TFProgram.cpp diff --git a/TensorFrost/src/Compiler/Value.cpp b/TensorFrost/Compiler/src/Value.cpp similarity index 100% rename from TensorFrost/src/Compiler/Value.cpp rename to 
TensorFrost/Compiler/src/Value.cpp From c260a24aa4434a26838baf68d00ca131420f3273 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Thu, 6 Nov 2025 01:17:28 +0100 Subject: [PATCH 24/44] Comment out of date tests --- tests/autograd_test.py | 6 +++++- tests/linalg_test.py | 16 ++++++++++------ tests/reshape_reduction_test.py | 6 +++++- tests/sorting_opengl_test.py | 5 ++++- tests/sorting_test.py | 5 ++++- tests/split_dim_test.py | 5 ++++- 6 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tests/autograd_test.py b/tests/autograd_test.py index 46c66e61..c8c4f0c9 100644 --- a/tests/autograd_test.py +++ b/tests/autograd_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + import TensorFrost as tf import numpy as np import torch @@ -234,4 +236,6 @@ def test_autograd(self): self.assertTrue(np.allclose(yhat_tf.numpy, yhat.detach().numpy(), atol=1e-5)) for i, param in enumerate(model_torch.parameters()): self.assertTrue(np.allclose(tf_grads.grad[i].numpy, param.grad.detach().numpy(), atol=1e-3)) - self.assertTrue(np.allclose(tf_grads.net.parameters()[i].numpy, param.detach().numpy(), atol=1e-3)) \ No newline at end of file + self.assertTrue(np.allclose(tf_grads.net.parameters()[i].numpy, param.detach().numpy(), atol=1e-3)) +""" + diff --git a/tests/linalg_test.py b/tests/linalg_test.py index e40ab565..8ea11f74 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. 
+ import numpy as np import TensorFrost as tf import unittest @@ -104,10 +106,12 @@ def test_qr_inversion(self): norm_error = np.linalg.norm(np.dot(Q, R) - np.dot(Qnp, Rnp)) print("QR decomposition error: ", norm_error) self.assertTrue(norm_error < 1e-5) - norm_error = np.linalg.norm(Rinv - Rinvtf) - print("Triangular matrix inversion error: ", norm_error) - self.assertTrue(norm_error < 5e-3) - norm_error = np.linalg.norm(Ainv - Ainvtf) - print("Matrix inversion error: ", norm_error) - self.assertTrue(norm_error < 5e-3) + norm_error = np.linalg.norm(Rinv - Rinvtf) + print("Triangular matrix inversion error: ", norm_error) + self.assertTrue(norm_error < 5e-3) + norm_error = np.linalg.norm(Ainv - Ainvtf) + print("Matrix inversion error: ", norm_error) + self.assertTrue(norm_error < 5e-3) +""" + diff --git a/tests/reshape_reduction_test.py b/tests/reshape_reduction_test.py index c2ffc3d0..54ff724e 100644 --- a/tests/reshape_reduction_test.py +++ b/tests/reshape_reduction_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + import TensorFrost as tf import numpy as np import unittest @@ -46,4 +48,6 @@ def test_reduction_reshape(self): self.assertTrue(np.allclose(norm_tf.numpy, norm_np)) self.assertTrue(np.allclose(total_max_tf.numpy, total_max_np)) self.assertTrue(np.allclose(total_min_tf.numpy, total_min_np)) - self.assertTrue(np.allclose(mean_block_tf.numpy, mean_block_np)) \ No newline at end of file + self.assertTrue(np.allclose(mean_block_tf.numpy, mean_block_np)) +""" + diff --git a/tests/sorting_opengl_test.py b/tests/sorting_opengl_test.py index 8cdb31e8..25e94f29 100644 --- a/tests/sorting_opengl_test.py +++ b/tests/sorting_opengl_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. 
+ # %% import numpy as np import TensorFrost as tf @@ -68,4 +70,5 @@ def test_sorting(self): # check if the results are the same error_radix = np.sum(np.abs(sorted_keys0.numpy - sorted_keys2)) print("Radix float errors: ", error_radix) - self.assertTrue(error_radix == 0) \ No newline at end of file + self.assertTrue(error_radix == 0) +""" diff --git a/tests/sorting_test.py b/tests/sorting_test.py index 2fa4a3d0..623f86f9 100644 --- a/tests/sorting_test.py +++ b/tests/sorting_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + # %% import numpy as np import TensorFrost as tf @@ -68,4 +70,5 @@ def test_sorting(self): # check if the results are the same error_radix = np.sum(np.abs(sorted_keys0.numpy - sorted_keys2)) print("Radix float errors: ", error_radix) - self.assertTrue(error_radix == 0) \ No newline at end of file + self.assertTrue(error_radix == 0) +""" diff --git a/tests/split_dim_test.py b/tests/split_dim_test.py index 12d05270..211fd699 100644 --- a/tests/split_dim_test.py +++ b/tests/split_dim_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. 
+ import numpy as np import TensorFrost as tf import unittest @@ -21,4 +23,5 @@ def test_split_dim(self): print(merged.shape) self.assertTrue(merged.shape == [128, 128, 32]) print(splitted.shape) - self.assertTrue(splitted.shape == [4, 32, 128, 32]) \ No newline at end of file + self.assertTrue(splitted.shape == [4, 32, 128, 32]) +""" From c5b48a2a6f84a438689d5c7848a6404bba7a1a0f Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Thu, 6 Nov 2025 04:14:01 +0100 Subject: [PATCH 25/44] vulkan imgui integration --- TensorFrost/Backend/CMakeLists.txt | 4 +- TensorFrost/Backend/include/Backend/Window.h | 44 +++ TensorFrost/Backend/src/Window.cpp | 334 +++++++++++++++++- TensorFrost/Compiler/CMakeLists.txt | 6 + TensorFrost/PybindModule.cpp | 2 - .../src/Definitions/VulkanBindings.cpp | 40 ++- .../src/Definitions/VulkanInterface.cpp | 108 ++++++ TensorFrost/src/Definitions/VulkanInterface.h | 26 ++ TensorFrost/src/Definitions/WindowUtils.cpp | 148 -------- examples/Slang/mandelbrot.py | 131 +++++-- tests/imgui_test.py | 76 ++++ 11 files changed, 727 insertions(+), 192 deletions(-) delete mode 100644 TensorFrost/src/Definitions/WindowUtils.cpp create mode 100644 tests/imgui_test.py diff --git a/TensorFrost/Backend/CMakeLists.txt b/TensorFrost/Backend/CMakeLists.txt index e4e3648d..4df5c05e 100644 --- a/TensorFrost/Backend/CMakeLists.txt +++ b/TensorFrost/Backend/CMakeLists.txt @@ -15,7 +15,9 @@ add_library(TensorFrostBackend STATIC target_include_directories(TensorFrostBackend PUBLIC ${TF_BACKEND_INC_DIR} - $ENV{VULKAN_SDK}/Include) + $ENV{VULKAN_SDK}/Include + ${CMAKE_SOURCE_DIR}/external/imgui + ${CMAKE_SOURCE_DIR}/external/imgui/backends) target_link_libraries(TensorFrostBackend PUBLIC diff --git a/TensorFrost/Backend/include/Backend/Window.h b/TensorFrost/Backend/include/Backend/Window.h index e18a5edd..9dd742f7 100644 --- a/TensorFrost/Backend/include/Backend/Window.h +++ b/TensorFrost/Backend/include/Backend/Window.h 
@@ -3,6 +3,12 @@ #include #include #include +#include + + +struct WindowContext; +void ReleaseImGui(WindowContext& ctx); +struct ImGuiContext; struct WindowContext { GLFWwindow* wnd{}; @@ -22,6 +28,14 @@ struct WindowContext { vk::Semaphore semImage, semDone; vk::Fence fence; + // ImGui integration helpers + vk::DescriptorPool imguiPool{}; + vk::RenderPass renderPass{}; + std::vector imageViews; + std::vector framebuffers; + ImGuiContext* imguiContext{}; + bool imguiFrameActive = false; + WindowContext() = default; WindowContext(const WindowContext&) = delete; WindowContext& operator=(const WindowContext&) = delete; @@ -52,13 +66,27 @@ struct WindowContext { semImage=o.semImage; o.semImage=nullptr; semDone=o.semDone; o.semDone=nullptr; fence=o.fence; o.fence=nullptr; + imguiPool=o.imguiPool; o.imguiPool=nullptr; + renderPass=o.renderPass; o.renderPass=nullptr; + imageViews=std::move(o.imageViews); + framebuffers=std::move(o.framebuffers); + imguiContext=o.imguiContext; o.imguiContext=nullptr; + imguiFrameActive=o.imguiFrameActive; o.imguiFrameActive=false; } void cleanup() { if (!wnd && !device) return; // already moved/clean // don’t terminate GLFW here; only destroy this window if (device) { + ReleaseImGui(*this); + (void)device.waitIdle(); + for (auto fb : framebuffers) device.destroyFramebuffer(fb); + framebuffers.clear(); + for (auto view : imageViews) device.destroyImageView(view); + imageViews.clear(); + if (imguiPool) { device.destroyDescriptorPool(imguiPool); imguiPool=nullptr; } + if (renderPass) { device.destroyRenderPass(renderPass); renderPass=nullptr; } if (fence) device.destroyFence(fence), fence=nullptr; if (semDone) device.destroySemaphore(semDone), semDone=nullptr; if (semImage) device.destroySemaphore(semImage), semImage=nullptr; @@ -77,3 +105,19 @@ WindowContext createWindow(int width, int height, const char* title); bool windowOpen(const WindowContext& ctx); void drawBuffer(WindowContext& ctx, vk::Buffer src, uint32_t width, uint32_t height, 
vk::DeviceSize offset = 0); void drawBuffer(WindowContext& ctx, const Buffer& b, uint32_t w, uint32_t h, size_t offset = 0); + +// Global helpers used by higher-level integrations / Python bindings +WindowContext* GetWindow(); +WindowContext& RequireWindow(); +ImGuiContext* GetImGuiContext(); +void EnsureImGuiFrame(WindowContext& ctx); + +void ShowWindow(int width, int height, const char* title); +void HideWindow(); +void RenderFrame(const Buffer* buffer, uint32_t width, uint32_t height, size_t offset = 0); +void RenderFrame(const Buffer* buffer = nullptr); +bool WindowShouldClose(); +std::pair GetMousePosition(); +std::pair GetWindowSize(); +bool IsMouseButtonPressed(int button); +bool IsKeyPressed(int key); diff --git a/TensorFrost/Backend/src/Window.cpp b/TensorFrost/Backend/src/Window.cpp index 4c06a63d..797e0ae0 100644 --- a/TensorFrost/Backend/src/Window.cpp +++ b/TensorFrost/Backend/src/Window.cpp @@ -1,6 +1,171 @@ #include "Backend/Vulkan.h" #include "Backend/Window.h" +#include +#include +#include + +#include +#include +#include +#include + +namespace { +std::unique_ptr gWindow; +std::mutex gWindowMutex; + +void CheckVkResult(VkResult err) { + if (err == VK_SUCCESS) return; + throw std::runtime_error("ImGui Vulkan backend error: " + std::to_string(err)); +} + +void EnsureFramebuffers(WindowContext& ctx) { + if (ctx.framebuffers.size() == ctx.images.size() && !ctx.framebuffers.empty()) return; + + if (!ctx.renderPass) { + vk::AttachmentDescription colorAttachment{}; + colorAttachment.format = ctx.format; + colorAttachment.samples = vk::SampleCountFlagBits::e1; + colorAttachment.loadOp = vk::AttachmentLoadOp::eLoad; + colorAttachment.storeOp = vk::AttachmentStoreOp::eStore; + colorAttachment.stencilLoadOp = vk::AttachmentLoadOp::eDontCare; + colorAttachment.stencilStoreOp = vk::AttachmentStoreOp::eDontCare; + colorAttachment.initialLayout = vk::ImageLayout::eColorAttachmentOptimal; + colorAttachment.finalLayout = vk::ImageLayout::ePresentSrcKHR; + + 
vk::AttachmentReference colorAttachmentRef{0, vk::ImageLayout::eColorAttachmentOptimal}; + + vk::SubpassDescription subpass{}; + subpass.pipelineBindPoint = vk::PipelineBindPoint::eGraphics; + subpass.colorAttachmentCount = 1; + subpass.pColorAttachments = &colorAttachmentRef; + + vk::SubpassDependency dependency{}; + dependency.srcSubpass = VK_SUBPASS_EXTERNAL; + dependency.dstSubpass = 0; + dependency.srcStageMask = vk::PipelineStageFlagBits::eColorAttachmentOutput; + dependency.dstStageMask = vk::PipelineStageFlagBits::eColorAttachmentOutput; + dependency.srcAccessMask = {}; + dependency.dstAccessMask = vk::AccessFlagBits::eColorAttachmentWrite; + + vk::RenderPassCreateInfo rpci{}; + rpci.attachmentCount = 1; + rpci.pAttachments = &colorAttachment; + rpci.subpassCount = 1; + rpci.pSubpasses = &subpass; + rpci.dependencyCount = 1; + rpci.pDependencies = &dependency; + + ctx.renderPass = ctx.device.createRenderPass(rpci); + } + + if (ctx.imageViews.size() != ctx.images.size()) { + for (auto view : ctx.imageViews) ctx.device.destroyImageView(view); + ctx.imageViews.clear(); + ctx.imageViews.reserve(ctx.images.size()); + for (auto image : ctx.images) { + vk::ImageViewCreateInfo viewInfo{}; + viewInfo.image = image; + viewInfo.viewType = vk::ImageViewType::e2D; + viewInfo.format = ctx.format; + viewInfo.components = vk::ComponentMapping(); + viewInfo.subresourceRange = {vk::ImageAspectFlagBits::eColor, 0, 1, 0, 1}; + ctx.imageViews.push_back(ctx.device.createImageView(viewInfo)); + } + } + + for (auto fb : ctx.framebuffers) ctx.device.destroyFramebuffer(fb); + ctx.framebuffers.clear(); + ctx.framebuffers.reserve(ctx.imageViews.size()); + for (auto view : ctx.imageViews) { + vk::FramebufferCreateInfo fbci{}; + fbci.renderPass = ctx.renderPass; + fbci.attachmentCount = 1; + fbci.pAttachments = &view; + fbci.width = ctx.extent.width; + fbci.height = ctx.extent.height; + fbci.layers = 1; + ctx.framebuffers.push_back(ctx.device.createFramebuffer(fbci)); + } +} + +void 
EnsureImGui(WindowContext& ctx, uint32_t imageCount) { + if (ctx.imguiContext) return; + + EnsureFramebuffers(ctx); + + std::array poolSizes = { + vk::DescriptorPoolSize{vk::DescriptorType::eSampler, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eCombinedImageSampler, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eSampledImage, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eStorageImage, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eUniformTexelBuffer, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eStorageTexelBuffer, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eUniformBuffer, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eStorageBuffer, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eUniformBufferDynamic, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eStorageBufferDynamic, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eInputAttachment, 1000} + }; + + vk::DescriptorPoolCreateInfo poolInfo(vk::DescriptorPoolCreateFlagBits::eFreeDescriptorSet, + 1000 * static_cast(poolSizes.size()), + static_cast(poolSizes.size()), + poolSizes.data()); + ctx.imguiPool = ctx.device.createDescriptorPool(poolInfo); + + IMGUI_CHECKVERSION(); + ctx.imguiContext = ImGui::CreateContext(); + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui::StyleColorsDark(); + + ImGui_ImplGlfw_InitForVulkan(ctx.wnd, true); + + ImGui_ImplVulkan_InitInfo initInfo{}; + initInfo.Instance = ctx.instance; + initInfo.PhysicalDevice = ctx.phys; + initInfo.Device = ctx.device; + initInfo.QueueFamily = ctx.presentFam; + initInfo.Queue = ctx.queue; + initInfo.Subpass = 0; + initInfo.MinImageCount = imageCount; + initInfo.ImageCount = imageCount; + initInfo.MSAASamples = VK_SAMPLE_COUNT_1_BIT; + initInfo.Allocator = nullptr; + initInfo.PipelineCache = VK_NULL_HANDLE; + initInfo.DescriptorPool = ctx.imguiPool; + initInfo.CheckVkResultFn = CheckVkResult; + initInfo.RenderPass = static_cast(ctx.renderPass); + + if (!ImGui_ImplVulkan_Init(&initInfo)) { + 
throw std::runtime_error("ImGui_ImplVulkan_Init failed"); + } + + if (!ImGui_ImplVulkan_CreateFontsTexture()) { + throw std::runtime_error("ImGui_ImplVulkan_CreateFontsTexture failed"); + } +} + +void StartImGuiFrame(WindowContext& ctx) { + if (!ctx.imguiContext || ctx.imguiFrameActive) return; + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui_ImplVulkan_NewFrame(); + ImGui_ImplGlfw_NewFrame(); + ImGui::NewFrame(); + ctx.imguiFrameActive = true; +} +} // namespace + +void ReleaseImGui(WindowContext& ctx) { + if (!ctx.imguiContext) return; + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui_ImplVulkan_Shutdown(); + ImGui_ImplGlfw_Shutdown(); + ImGui::DestroyContext(ctx.imguiContext); + ctx.imguiContext = nullptr; + ctx.imguiFrameActive = false; +} + WindowContext createWindow(int width, int height, const char* title) { auto& vctx = getVulkanContext(); if (!glfwInit()) throw std::runtime_error("glfwInit"); @@ -63,6 +228,10 @@ WindowContext createWindow(int width, int height, const char* title) { ctx.semImage = ctx.device.createSemaphore({}); ctx.semDone = ctx.device.createSemaphore({}); ctx.fence = ctx.device.createFence({}); + + EnsureImGui(ctx, imageCount); + StartImGuiFrame(ctx); + return ctx; } @@ -81,29 +250,76 @@ void drawBuffer(WindowContext &ctx, vk::Buffer src, uint32_t width, uint32_t hei ctx.cmd.reset({}); ctx.cmd.begin({vk::CommandBufferUsageFlagBits::eOneTimeSubmit}); - vk::ImageMemoryBarrier toDst({}, vk::AccessFlagBits::eTransferWrite, - vk::ImageLayout::eUndefined, vk::ImageLayout::eTransferDstOptimal, - VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, - ctx.images[idx], {vk::ImageAspectFlagBits::eColor, 0,1, 0,1}); + vk::ImageSubresourceRange range{vk::ImageAspectFlagBits::eColor, 0, 1, 0, 1}; + + vk::ImageMemoryBarrier toTransfer({}, vk::AccessFlagBits::eTransferWrite, + vk::ImageLayout::eUndefined, vk::ImageLayout::eTransferDstOptimal, + VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, + ctx.images[idx], range); 
ctx.cmd.pipelineBarrier(vk::PipelineStageFlagBits::eTopOfPipe, - vk::PipelineStageFlagBits::eTransfer, {}, {}, {}, toDst); + vk::PipelineStageFlagBits::eTransfer, + {}, nullptr, nullptr, toTransfer); + + if (src) { + vk::BufferImageCopy copy{}; + copy.bufferOffset = offset; + copy.imageSubresource = {vk::ImageAspectFlagBits::eColor, 0, 0, 1}; + copy.imageExtent = vk::Extent3D{ width, height, 1 }; + ctx.cmd.copyBufferToImage(src, ctx.images[idx], vk::ImageLayout::eTransferDstOptimal, 1, &copy); + } - vk::BufferImageCopy copy{}; - copy.bufferOffset = offset; - copy.imageSubresource = {vk::ImageAspectFlagBits::eColor, 0, 0, 1}; - copy.imageExtent = vk::Extent3D{ width, height, 1 }; - ctx.cmd.copyBufferToImage(src, ctx.images[idx], vk::ImageLayout::eTransferDstOptimal, 1, &copy); + vk::ImageMemoryBarrier toColor(src ? vk::AccessFlagBits::eTransferWrite : vk::AccessFlagBits{}, + vk::AccessFlagBits::eColorAttachmentWrite, + src ? vk::ImageLayout::eTransferDstOptimal : vk::ImageLayout::eUndefined, + vk::ImageLayout::eColorAttachmentOptimal, + VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, + ctx.images[idx], range); + ctx.cmd.pipelineBarrier(src ? 
vk::PipelineStageFlagBits::eTransfer : vk::PipelineStageFlagBits::eTopOfPipe, + vk::PipelineStageFlagBits::eColorAttachmentOutput, + {}, nullptr, nullptr, toColor); - vk::ImageMemoryBarrier toPresent(vk::AccessFlagBits::eTransferWrite, {}, - vk::ImageLayout::eTransferDstOptimal, vk::ImageLayout::ePresentSrcKHR, + EnsureFramebuffers(ctx); + + if (!src) { + vk::ClearColorValue clearColor(std::array{0.f, 0.f, 0.f, 1.f}); + std::array ranges{range}; + ctx.cmd.clearColorImage(ctx.images[idx], vk::ImageLayout::eColorAttachmentOptimal, clearColor, ranges); + } + + vk::RenderPassBeginInfo rpBegin{}; + rpBegin.renderPass = ctx.renderPass; + rpBegin.framebuffer = ctx.framebuffers[idx]; + rpBegin.renderArea.offset = vk::Offset2D{0, 0}; + rpBegin.renderArea.extent = ctx.extent; + vk::ClearValue clearValue{}; + rpBegin.clearValueCount = 1; + rpBegin.pClearValues = &clearValue; + + ctx.cmd.beginRenderPass(rpBegin, vk::SubpassContents::eInline); + + if (ctx.imguiContext && ctx.imguiFrameActive) { + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui::Render(); + ImDrawData* drawData = ImGui::GetDrawData(); + if (drawData && drawData->CmdListsCount > 0) { + ImGui_ImplVulkan_RenderDrawData(drawData, static_cast(ctx.cmd)); + } + ctx.imguiFrameActive = false; + } + + ctx.cmd.endRenderPass(); + + vk::ImageMemoryBarrier toPresent(vk::AccessFlagBits::eColorAttachmentWrite, {}, + vk::ImageLayout::eColorAttachmentOptimal, vk::ImageLayout::ePresentSrcKHR, VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, - ctx.images[idx], {vk::ImageAspectFlagBits::eColor, 0,1, 0,1}); - ctx.cmd.pipelineBarrier(vk::PipelineStageFlagBits::eTransfer, - vk::PipelineStageFlagBits::eBottomOfPipe, {}, {}, {}, toPresent); + ctx.images[idx], range); + ctx.cmd.pipelineBarrier(vk::PipelineStageFlagBits::eColorAttachmentOutput, + vk::PipelineStageFlagBits::eBottomOfPipe, + {}, nullptr, nullptr, toPresent); ctx.cmd.end(); - vk::PipelineStageFlags waitStage = vk::PipelineStageFlagBits::eTransfer; + 
vk::PipelineStageFlags waitStage = vk::PipelineStageFlagBits::eColorAttachmentOutput; (void)ctx.device.resetFences(ctx.fence); ctx.queue.submit({vk::SubmitInfo(1, &ctx.semImage, &waitStage, 1, &ctx.cmd, 1, &ctx.semDone)}, ctx.fence); @@ -114,6 +330,8 @@ void drawBuffer(WindowContext &ctx, vk::Buffer src, uint32_t width, uint32_t hei } catch (const vk::OutOfDateKHRError&) { // ignore } + + StartImGuiFrame(ctx); } void drawBuffer(WindowContext &ctx, const Buffer &b, uint32_t w, uint32_t h, size_t offset) { @@ -121,3 +339,87 @@ void drawBuffer(WindowContext &ctx, const Buffer &b, uint32_t w, uint32_t h, siz if (offset + size_t(w)*size_t(h)*4 > b.size) throw std::out_of_range("buffer too small"); drawBuffer(ctx, b.buffer, w, h, offset); } + +WindowContext* GetWindow() { + std::scoped_lock lock(gWindowMutex); + return gWindow.get(); +} + +WindowContext& RequireWindow() { + auto* wnd = GetWindow(); + if (!wnd) throw std::runtime_error("Window not created. Call ShowWindow() first."); + return *wnd; +} + +ImGuiContext* GetImGuiContext() { + auto* wnd = GetWindow(); + return wnd ? wnd->imguiContext : nullptr; +} + +void EnsureImGuiFrame(WindowContext& ctx) { + StartImGuiFrame(ctx); +} + +void ShowWindow(int width, int height, const char* title) { + std::scoped_lock lock(gWindowMutex); + if (gWindow && windowOpen(*gWindow)) return; + gWindow = std::make_unique<WindowContext>(createWindow(width, height, title)); +} + +void HideWindow() { + std::scoped_lock lock(gWindowMutex); + gWindow.reset(); +} + +void RenderFrame(const Buffer* buffer, uint32_t width, uint32_t height, size_t offset) { + auto& ctx = RequireWindow(); + EnsureImGuiFrame(ctx); + + uint32_t w = width ? width : ctx.extent.width; + uint32_t h = height ? 
height : ctx.extent.height; + + if (buffer) { + drawBuffer(ctx, buffer->buffer, w, h, offset); + } else { + drawBuffer(ctx, vk::Buffer{}, w, h, offset); + } +} + +void RenderFrame(const Buffer* buffer) { + auto& ctx = RequireWindow(); + RenderFrame(buffer, ctx.extent.width, ctx.extent.height, 0); +} + +bool WindowShouldClose() { + std::scoped_lock lock(gWindowMutex); + if (!gWindow) return true; + return !windowOpen(*gWindow); +} + +std::pair GetMousePosition() { + std::scoped_lock lock(gWindowMutex); + if (!gWindow || !gWindow->wnd) return {0.0, 0.0}; + double x{}, y{}; + glfwGetCursorPos(gWindow->wnd, &x, &y); + return {x, y}; +} + +std::pair GetWindowSize() { + std::scoped_lock lock(gWindowMutex); + if (!gWindow || !gWindow->wnd) return {0, 0}; + int w{}, h{}; + glfwGetWindowSize(gWindow->wnd, &w, &h); + return {w, h}; +} + +bool IsMouseButtonPressed(int button) { + std::scoped_lock lock(gWindowMutex); + if (!gWindow || !gWindow->wnd) return false; + return glfwGetMouseButton(gWindow->wnd, button) == GLFW_PRESS; +} + +bool IsKeyPressed(int key) { + std::scoped_lock lock(gWindowMutex); + if (!gWindow || !gWindow->wnd) return false; + return glfwGetKey(gWindow->wnd, key) == GLFW_PRESS; +} diff --git a/TensorFrost/Compiler/CMakeLists.txt b/TensorFrost/Compiler/CMakeLists.txt index be053e5f..5dccaf30 100644 --- a/TensorFrost/Compiler/CMakeLists.txt +++ b/TensorFrost/Compiler/CMakeLists.txt @@ -20,6 +20,12 @@ target_compile_features(TensorFrostCompiler PUBLIC cxx_std_20) if (MSVC) target_compile_options(TensorFrostCompiler PRIVATE /wd4804 /wd4805 /wd4018) + target_compile_definitions(TensorFrostCompiler PRIVATE + $<$:_ITERATOR_DEBUG_LEVEL=2> + $<$:_HAS_ITERATOR_DEBUGGING=1> + ) + set_property(TARGET TensorFrostCompiler PROPERTY + MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL") endif() source_group(TREE ${TF_COMPILER_SRC_DIR} PREFIX "Source Files" diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index dc4331ea..253d1d23 100644 --- 
a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -21,7 +21,6 @@ namespace TensorFrost { // void TensorProgramDefinition(py::module&, py::class_&); // void TensorMemoryDefinition(py::module& m, // py::class_& py_tensor_mem); -// void WindowDefinitions(py::module& m); // void ScopeDefinitions(py::module& m, py::class_& py_tensor); // void ModuleDefinitions(py::module& m); @@ -102,7 +101,6 @@ PYBIND11_MODULE(TensorFrost, m) { // TensorFunctionsDefinition(m); // TensorProgramDefinition(m, tensor_program); // TensorMemoryDefinition(m, py_tensor_mem); - // WindowDefinitions(m); // ScopeDefinitions(m, py_tensor); // ModuleDefinitions(m); // diff --git a/TensorFrost/src/Definitions/VulkanBindings.cpp b/TensorFrost/src/Definitions/VulkanBindings.cpp index 2ffacd24..0a642a05 100644 --- a/TensorFrost/src/Definitions/VulkanBindings.cpp +++ b/TensorFrost/src/Definitions/VulkanBindings.cpp @@ -1,8 +1,11 @@ #include "Definitions/VulkanBindings.h" #include "VulkanInterface.h" +#include + #include #include +#include namespace py = pybind11; @@ -79,8 +82,43 @@ void VulkanDefinitions(py::module_& m) { .def("drawBuffer", &PyWindow::drawBuffer, py::arg("buffer"), py::arg("width"), py::arg("height"), py::arg("offset") = 0, "Copy a buffer of packed pixels onto the swapchain.") + .def("present", &PyWindow::present, + "Present the current frame without uploading new pixels.") .def("close", &PyWindow::close, - "Destroy the window and release its swapchain resources."); + "Destroy the window and release its swapchain resources.") + .def("imgui_begin", &PyWindow::imguiBegin, + py::arg("name"), py::arg("open") = py::none(), py::arg("flags") = 0, + "Begin a new ImGui window, returning (visible, open_flag_or_None).") + .def("imgui_end", &PyWindow::imguiEnd, + "End the current ImGui window.") + .def("imgui_text", &PyWindow::imguiText, + py::arg("text"), + "Add text to the current ImGui window.") + .def("imgui_button", &PyWindow::imguiButton, + py::arg("label"), + "Add a button 
and return True when pressed.") + .def("imgui_checkbox", &PyWindow::imguiCheckbox, + py::arg("label"), py::arg("value"), + "Add a checkbox and return the updated value.") + .def("imgui_slider_int", &PyWindow::imguiSliderInt, + py::arg("label"), py::arg("value"), py::arg("min"), py::arg("max"), + "Slider that returns the updated integer value.") + .def("imgui_slider_float", &PyWindow::imguiSliderFloat, + py::arg("label"), py::arg("value"), py::arg("min"), py::arg("max"), + "Slider that returns the updated float value.") + .def("imgui_plot_lines", &PyWindow::imguiPlotLines, + py::arg("label"), py::arg("values"), py::arg("values_offset") = 0, + py::arg("overlay_text") = "", py::arg("scale_min") = FLT_MAX, + py::arg("scale_max") = FLT_MAX, + py::arg("graph_size") = py::make_tuple(0.0f, 0.0f), + py::arg("stride") = sizeof(float), + "Plot a sequence of values as lines.") + .def("imgui_scale_all_sizes", &PyWindow::imguiScaleAllSizes, + py::arg("scale"), + "Scale all ImGui sizes by a factor.") + .def("imgui_add_background_text", &PyWindow::imguiAddBackgroundText, + py::arg("text"), py::arg("pos"), py::arg("color"), + "Draw text in the background draw list."); m.def("createWindow", [](int width, int height, const std::string& title) { diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp index 7cfc8533..ce5e20d8 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.cpp +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -8,6 +8,8 @@ #include #include +#include + #include "Backend/Vulkan.h" #include "Backend/Window.h" @@ -252,6 +254,12 @@ void PyWindow::drawBuffer(const PyBuffer& buffer, uint32_t width, uint32_t heigh ::drawBuffer(window_, buffer.raw(), width, height, offset); } +void PyWindow::present() { + ensureValid(); + py::gil_scoped_release release; + ::drawBuffer(window_, vk::Buffer{}, window_.extent.width, window_.extent.height, 0); +} + py::tuple PyWindow::size() const { ensureValid(); return 
py::make_tuple(window_.extent.width, window_.extent.height); @@ -267,6 +275,90 @@ void PyWindow::close() { ctx_ = nullptr; } +py::tuple PyWindow::imguiBegin(const std::string& name, + const std::optional& open, + int flags) { + bindImGui(); + bool openValue = open.value_or(true); + bool visible = ImGui::Begin(name.c_str(), open ? &openValue : nullptr, flags); + return py::make_tuple(visible, open ? py::cast(openValue) : py::none()); +} + +void PyWindow::imguiEnd() { + bindImGui(); + ImGui::End(); +} + +void PyWindow::imguiText(const std::string& text) { + bindImGui(); + ImGui::TextUnformatted(text.c_str()); +} + +bool PyWindow::imguiButton(const std::string& label) { + bindImGui(); + return ImGui::Button(label.c_str()); +} + +bool PyWindow::imguiCheckbox(const std::string& label, bool value) { + bindImGui(); + bool v = value; + ImGui::Checkbox(label.c_str(), &v); + return v; +} + +int PyWindow::imguiSliderInt(const std::string& label, int value, int min, int max) { + bindImGui(); + int v = value; + ImGui::SliderInt(label.c_str(), &v, min, max); + return v; +} + +float PyWindow::imguiSliderFloat(const std::string& label, float value, float min, float max) { + bindImGui(); + float v = value; + ImGui::SliderFloat(label.c_str(), &v, min, max); + return v; +} + +void PyWindow::imguiPlotLines(const std::string& label, + py::array_t values, + int valuesOffset, + const std::string& overlayText, + float scaleMin, + float scaleMax, + py::tuple graphSize, + int stride) { + bindImGui(); + validateTupleSize(graphSize, 2, "graph_size"); + ImGui::PlotLines( + label.c_str(), + values.data(), + static_cast(values.size()), + valuesOffset, + overlayText.empty() ? 
nullptr : overlayText.c_str(), + scaleMin, + scaleMax, + ImVec2(graphSize[0].cast(), graphSize[1].cast()), + stride); +} + +void PyWindow::imguiScaleAllSizes(float scale) { + bindImGui(); + ImGui::GetStyle().ScaleAllSizes(scale); +} + +void PyWindow::imguiAddBackgroundText(const std::string& text, + py::tuple pos, + py::tuple color) { + bindImGui(); + validateTupleSize(pos, 2, "pos"); + validateTupleSize(color, 4, "color"); + ImGui::GetBackgroundDrawList()->AddText( + ImVec2(pos[0].cast(), pos[1].cast()), + ImColor(color[0].cast(), color[1].cast(), color[2].cast(), color[3].cast()), + text.c_str()); +} + void PyWindow::ensureValid() const { if (!window_.wnd) { throw std::runtime_error("Window has been closed"); @@ -279,6 +371,22 @@ void PyWindow::moveFrom(PyWindow&& other) { other.ctx_ = nullptr; } +ImGuiContext* PyWindow::bindImGui() { + ensureValid(); + EnsureImGuiFrame(window_); + if (!window_.imguiContext) { + throw std::runtime_error("ImGui context is not initialized for this window"); + } + ImGui::SetCurrentContext(window_.imguiContext); + return window_.imguiContext; +} + +void PyWindow::validateTupleSize(const py::tuple& tpl, size_t expected, const char* name) { + if (tpl.size() != expected) { + throw std::invalid_argument(std::string("Expected tuple of size ") + std::to_string(expected) + " for " + name); + } +} + PyComputeProgram MakeComputeProgramFromGLSL(const std::string& source, uint32_t roCount, uint32_t rwCount) { diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h index b172f25c..7c0441b2 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.h +++ b/TensorFrost/src/Definitions/VulkanInterface.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -98,13 +99,38 @@ class PyWindow { bool isOpen() const; void drawBuffer(const PyBuffer& buffer, uint32_t width, uint32_t height, size_t offset); + void present(); pybind11::tuple size() const; int format() const; void close(); + 
pybind11::tuple imguiBegin(const std::string& name, + const std::optional& open, + int flags); + void imguiEnd(); + void imguiText(const std::string& text); + bool imguiButton(const std::string& label); + bool imguiCheckbox(const std::string& label, bool value); + int imguiSliderInt(const std::string& label, int value, int min, int max); + float imguiSliderFloat(const std::string& label, float value, float min, float max); + void imguiPlotLines(const std::string& label, + pybind11::array_t values, + int valuesOffset, + const std::string& overlayText, + float scaleMin, + float scaleMax, + pybind11::tuple graphSize, + int stride); + void imguiScaleAllSizes(float scale); + void imguiAddBackgroundText(const std::string& text, + pybind11::tuple pos, + pybind11::tuple color); + private: void ensureValid() const; void moveFrom(PyWindow&& other); + ImGuiContext* bindImGui(); + static void validateTupleSize(const pybind11::tuple& tpl, size_t expected, const char* name); VulkanContext* ctx_{}; WindowContext window_{}; diff --git a/TensorFrost/src/Definitions/WindowUtils.cpp b/TensorFrost/src/Definitions/WindowUtils.cpp deleted file mode 100644 index 5bbf4647..00000000 --- a/TensorFrost/src/Definitions/WindowUtils.cpp +++ /dev/null @@ -1,148 +0,0 @@ -// #include -// #include -// -// #include -// #include -// -// #include "Backend/RenderDoc.h" -// -// namespace TensorFrost { -// -// void WindowDefinitions(py::module& m) { -// py::module window = m.def_submodule("window", "Window functions"); -// -// window.def( -// "show", -// [](int width, int height, string title) { -// ShowWindow(width, height, title.c_str()); -// }, -// "Show the memory manager window"); -// -// window.def( -// "hide", []() { HideWindow(); }, "Hide the memory manager window"); -// -// window.def( -// "render_frame", [](const PyTensorMemory& t) { RenderFrame(t.tensor_); }, -// "Render a frame from the tensor memory"); -// -// window.def("render_frame", []() { RenderFrame(nullptr); }, -// "Render an empty 
frame"); -// -// window.def( -// "should_close", []() { return WindowShouldClose(); }, -// "Check if the window should close"); -// -// window.def( -// "get_mouse_position", []() { return GetMousePosition(); }, -// "Get the current mouse position"); -// -// window.def( -// "get_size", []() { return GetWindowSize(); }, -// "Get the current window size"); -// -// window.def( -// "is_mouse_button_pressed", -// [](int button) { return IsMouseButtonPressed(button); }, -// "Check if a mouse button is pressed"); -// -// window.def( -// "is_key_pressed", [](int key) { return IsKeyPressed(key); }, -// "Check if a key is pressed"); -// -// window.attr("MOUSE_BUTTON_0") = GLFW_MOUSE_BUTTON_1; -// window.attr("MOUSE_BUTTON_1") = GLFW_MOUSE_BUTTON_2; -// window.attr("MOUSE_BUTTON_2") = GLFW_MOUSE_BUTTON_3; -// -// window.attr("KEY_SPACE") = GLFW_KEY_SPACE; -// window.attr("KEY_APOSTROPHE") = GLFW_KEY_APOSTROPHE; -// window.attr("KEY_COMMA") = GLFW_KEY_COMMA; -// window.attr("KEY_MINUS") = GLFW_KEY_MINUS; -// window.attr("KEY_PERIOD") = GLFW_KEY_PERIOD; -// -// window.attr("KEY_A") = GLFW_KEY_A; -// window.attr("KEY_B") = GLFW_KEY_B; -// window.attr("KEY_C") = GLFW_KEY_C; -// window.attr("KEY_D") = GLFW_KEY_D; -// window.attr("KEY_E") = GLFW_KEY_E; -// window.attr("KEY_F") = GLFW_KEY_F; -// window.attr("KEY_G") = GLFW_KEY_G; -// window.attr("KEY_H") = GLFW_KEY_H; -// window.attr("KEY_I") = GLFW_KEY_I; -// window.attr("KEY_J") = GLFW_KEY_J; -// window.attr("KEY_K") = GLFW_KEY_K; -// window.attr("KEY_L") = GLFW_KEY_L; -// window.attr("KEY_M") = GLFW_KEY_M; -// window.attr("KEY_N") = GLFW_KEY_N; -// window.attr("KEY_O") = GLFW_KEY_O; -// window.attr("KEY_P") = GLFW_KEY_P; -// window.attr("KEY_Q") = GLFW_KEY_Q; -// window.attr("KEY_R") = GLFW_KEY_R; -// window.attr("KEY_S") = GLFW_KEY_S; -// window.attr("KEY_T") = GLFW_KEY_T; -// window.attr("KEY_U") = GLFW_KEY_U; -// window.attr("KEY_V") = GLFW_KEY_V; -// window.attr("KEY_W") = GLFW_KEY_W; -// window.attr("KEY_X") = GLFW_KEY_X; -// 
window.attr("KEY_Y") = GLFW_KEY_Y; -// window.attr("KEY_Z") = GLFW_KEY_Z; -// -// py::module imgui = m.def_submodule("imgui", "ImGui functions"); -// -// imgui.def( -// "begin", [](string name) { ImGuiBegin(name); }, -// "Begin a new ImGui window"); -// -// imgui.def("end", []() { ImGuiEnd(); }, "End the current ImGui window"); -// -// imgui.def( -// "text", [](string text) { ImGuiText(text); }, -// "Add text to the current ImGui window"); -// -// imgui.def("slider", -// [](string text, int* value, int min, int max) { -// ImGuiSlider(text, value, min, max); -// return *value; -// }, "Add a slider to the current ImGui window"); -// -// imgui.def("slider", [](string text, float* value, float min, float max) { -// ImGuiSlider(text, value, min, max); -// return *value; -// }, "Add a slider to the current ImGui window"); -// -// imgui.def("button", [](string text) { return ImGuiButton(text); }, -// "Add a button to the current ImGui window"); -// -// imgui.def("checkbox", [](string text, bool* value) { -// ImGuiCheckbox(text, value); -// return *value; -// }, "Add a checkbox to the current ImGui window"); -// -// imgui.def("plotlines", [](string label, py::array_t values, int values_offset, string overlay_text, float scale_min, float scale_max, py::tuple graph_size, int stride) { -// ImGuiPlotLines(label.c_str(), values.data(), (int)values.size(), values_offset, overlay_text.c_str(), scale_min, scale_max, ImVec2(graph_size[0].cast(), graph_size[1].cast()), stride); -// }, py::arg("label"), py::arg("values"), py::arg("values_offset") = 0, py::arg("overlay_text") = "", py::arg("scale_min") = FLT_MAX, py::arg("scale_max") = FLT_MAX, py::arg("graph_size") = py::make_tuple(0.0f, 0.0f), py::arg("stride") = sizeof(float)); -// -// imgui.def("scale_all_sizes", [](float scale) { ImGuiScaleAllSizes(scale); }, -// "Scale all ImGui sizes by a factor"); -// -// imgui.def("add_background_text", [](string text, py::tuple pos, py::tuple color) { -// ImGuiAddBackgroundText(text, 
ImVec2(pos[0].cast(), pos[1].cast()), ImVec4(color[0].cast(), color[1].cast(), color[2].cast(), color[3].cast())); -// }, py::arg("text"), py::arg("pos"), py::arg("color")); -// -// imgui.def("color_picker3", [](string text, py::array_t color) { -// ImGuiColorPicker3(text, color.mutable_data()); -// return color; -// }, py::arg("text"), py::arg("color")); -// -// imgui.def("color_picker4", [](string text, py::array_t color) { -// ImGuiColorPicker4(text, color.mutable_data()); -// return color; -// }, py::arg("text"), py::arg("color")); -// -// //renderdoc -// m.def("renderdoc_start_capture", []() { StartRenderDocCapture(); }, -// "Start a RenderDoc capture"); -// m.def("renderdoc_end_capture", []() { EndRenderDocCapture(); }, -// "End a RenderDoc capture"); -// } -// -// } // namespace TensorFrost \ No newline at end of file diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py index 2da3e2d0..1d5e62ee 100644 --- a/examples/Slang/mandelbrot.py +++ b/examples/Slang/mandelbrot.py @@ -1,33 +1,116 @@ +from pathlib import Path +import math + import numpy as np import TensorFrost as tf -from pathlib import Path -W, H = 1024, 768 -win = tf.createWindow(W, H, "Mandelbrot") -fmt = int(win.format) -is_bgra = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB -pix = tf.createBuffer(W*H, 4, False) # uint32 pixels -params = tf.createBuffer(8, 4, True) # 8 float32 params +def load_shader() -> str: + with open(Path(__file__).with_name("mandelbrot.slang"), "r", encoding="utf-8") as handle: + return handle.read() + + +def main() -> None: + width, height = 1024, 768 + win = tf.createWindow(width, height, "Mandelbrot (ImGui)") + + fmt = int(win.format) + is_bgra_default = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB + + pixel_capacity = max(1, width * height) + pixel_buffer = tf.createBuffer(pixel_capacity, 4, False) + params_buffer = tf.createBuffer(8, 4, True) + + shader_source = load_shader() + program = tf.createComputeProgramFromSlang("mandelbrot", 
shader_source, "csMain", ro_count=1, rw_count=1) + + center = [-0.5, 0.0] + scale = 3.0 + log_scale = math.log10(scale) + manual_iterations = 500 + auto_iterations = True + swap_rb = is_bgra_default + plot_history = np.zeros(120, dtype=np.float32) + history_index = 0 + + params = np.zeros(8, dtype=np.float32) + + def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: + nonlocal pixel_buffer, pixel_capacity + required = max(1, cur_width * cur_height) + if required != pixel_capacity: + pixel_buffer.release() + pixel_buffer = tf.createBuffer(required, 4, False) + pixel_capacity = required + + try: + while win.isOpen(): + width, height = win.size + width = max(1, int(width)) + height = max(1, int(height)) + + ensure_pixel_buffer(width, height) + + aspect = height / float(width) + + visible, _ = win.imgui_begin("Mandelbrot Controls", open=True) + if visible: + win.imgui_text(f"Resolution: {width} × {height}") + win.imgui_text(f"Format: {fmt}") + log_scale = win.imgui_slider_float("log₁₀ scale", log_scale, -4.5, 0.5) + scale = pow(10.0, log_scale) + center[0] = win.imgui_slider_float("Center X", center[0], -2.5, 1.5) + center[1] = win.imgui_slider_float("Center Y", center[1], -1.5, 1.5) + auto_iterations = win.imgui_checkbox("Auto iterations", auto_iterations) + if auto_iterations: + auto_value = max(64, int(200 + (-math.log10(scale)) * 120)) + manual_iterations = auto_value + win.imgui_text(f"Max iterations (auto): {auto_value}") + else: + manual_iterations = win.imgui_slider_int("Max iterations", manual_iterations, 64, 5000) + swap_rb = win.imgui_checkbox("BGRA swap", swap_rb) + if win.imgui_button("Reset view"): + center[:] = (-0.5, 0.0) + scale = 3.0 + log_scale = math.log10(scale) + manual_iterations = 500 + auto_iterations = True + + plot_history[history_index % plot_history.size] = float(manual_iterations) + history_index += 1 + win.imgui_plot_lines( + "Iteration history", + plot_history, + values_offset=history_index % plot_history.size, + 
graph_size=(0.0, 60.0), + ) + win.imgui_end() + + xspan = scale + yspan = xspan * aspect + xmin = center[0] - xspan * 0.5 + ymin = center[1] - yspan * 0.5 + dx = xspan / width + dy = yspan / height -with open(Path(__file__).with_name('mandelbrot.slang'), 'r') as f: - hlsl = f.read() + params[0] = float(width) + params[1] = float(height) + params[2] = xmin + params[3] = ymin + params[4] = dx + params[5] = dy + params[6] = float(manual_iterations) + params[7] = 1.0 if swap_rb else 0.0 -prog = tf.createComputeProgramFromSlang("mandelbrot", hlsl, "csMain", ro_count=1, rw_count=1) + params_buffer.setData(params) + program.run([params_buffer], [pixel_buffer], width * height) -# view rectangle with aspect correction -xspan = 3.0 -yspan = xspan * (H / float(W)) -xmin, ymin = -2.0, -yspan * 0.5 -dx, dy = xspan / W, yspan / H -max_iter = 500.0 + win.drawBuffer(pixel_buffer, width, height) + finally: + win.close() + pixel_buffer.release() + params_buffer.release() -p = np.array([float(W), float(H), xmin, ymin, dx, dy, max_iter, 1.0 if is_bgra else 0.0], dtype=np.float32) -params.setData(p) -try: - while win.isOpen(): - prog.run([params], [pix], W*H) - win.drawBuffer(pix, W, H) -finally: - win.close() +if __name__ == "__main__": + main() diff --git a/tests/imgui_test.py b/tests/imgui_test.py new file mode 100644 index 00000000..2c150139 --- /dev/null +++ b/tests/imgui_test.py @@ -0,0 +1,76 @@ +import unittest +from contextlib import contextmanager + +import numpy as np +import TensorFrost as tf + + +def _should_skip_for_backend(exc: RuntimeError) -> bool: + message = str(exc).lower() + keywords = ( + "not initialized", + "glfw", + "surface", + "no suitable", + "unavailable", + ) + return any(token in message for token in keywords) + + +@contextmanager +def managed_window(width=320, height=240, title="ImGui Test Window"): + try: + win = tf.createWindow(width, height, title) + except RuntimeError as exc: + if _should_skip_for_backend(exc): + raise unittest.SkipTest(f"Window 
backend unavailable: {exc}") from exc + raise + try: + yield win + finally: + try: + win.close() + except Exception: + pass + + +class ImGuiIntegrationTest(unittest.TestCase): + def test_imgui_basic_widgets(self): + with managed_window() as win: + visible, open_flag = win.imgui_begin("Main Panel") + self.assertTrue(visible, "ImGui window should be visible on begin") + self.assertIsNone(open_flag, "Default begin call should return None for open flag") + + win.imgui_text("Hello from TensorFrost tests") + + button_pressed = win.imgui_button("Press Me") + self.assertIn(button_pressed, (True, False)) + + self.assertTrue(win.imgui_checkbox("Checkbox", True)) + + slider_value = win.imgui_slider_float("Float Slider", 0.25, 0.0, 1.0) + self.assertGreaterEqual(slider_value, 0.0) + self.assertLessEqual(slider_value, 1.0) + + data = np.linspace(0.0, 1.0, 32, dtype=np.float32) + win.imgui_plot_lines("Plot", data, overlay_text="test", graph_size=(120.0, 40.0)) + + win.imgui_add_background_text("BG", (10.0, 10.0), (1.0, 1.0, 1.0, 1.0)) + win.imgui_scale_all_sizes(1.0) + win.imgui_end() + + visible_secondary, open_flag_secondary = win.imgui_begin("Secondary", open=True) + self.assertTrue(visible_secondary) + self.assertIsInstance(open_flag_secondary, bool) + win.imgui_text("Secondary window contents") + updated_int = win.imgui_slider_int("Int Slider", 5, 0, 10) + self.assertGreaterEqual(updated_int, 0) + self.assertLessEqual(updated_int, 10) + win.imgui_end() + + win.present() + self.assertIsInstance(win.isOpen(), bool) + + +if __name__ == "__main__": + unittest.main() From b5c6dbdc44dd86b9c6d480e20dc3f577d8aa6bad Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Thu, 6 Nov 2025 04:44:28 +0100 Subject: [PATCH 26/44] Add window resizing support --- TensorFrost/Backend/src/Window.cpp | 150 ++++++++++++++++-- .../src/Definitions/VulkanInterface.cpp | 9 +- TensorFrost/src/Definitions/VulkanInterface.h | 2 +- 3 files changed, 144 
insertions(+), 17 deletions(-) diff --git a/TensorFrost/Backend/src/Window.cpp b/TensorFrost/Backend/src/Window.cpp index 797e0ae0..88de44c1 100644 --- a/TensorFrost/Backend/src/Window.cpp +++ b/TensorFrost/Backend/src/Window.cpp @@ -154,6 +154,92 @@ void StartImGuiFrame(WindowContext& ctx) { ImGui::NewFrame(); ctx.imguiFrameActive = true; } + +void DestroySwapchainViews(WindowContext& ctx) { + for (auto fb : ctx.framebuffers) ctx.device.destroyFramebuffer(fb); + ctx.framebuffers.clear(); + for (auto view : ctx.imageViews) ctx.device.destroyImageView(view); + ctx.imageViews.clear(); +} + +vk::SurfaceFormatKHR SelectSurfaceFormat(const WindowContext& ctx, + const std::vector& formats) { + if (formats.empty()) { + return {vk::Format::eB8G8R8A8Unorm, vk::ColorSpaceKHR::eSrgbNonlinear}; + } + for (const auto& fmt : formats) { + if (fmt.format == ctx.format) { + return fmt; + } + } + return formats.front(); +} + +void RecreateSwapchain(WindowContext& ctx, vk::Extent2D desiredExtent) { + if (!ctx.wnd) return; + + ctx.device.waitIdle(); + + DestroySwapchainViews(ctx); + + auto caps = ctx.phys.getSurfaceCapabilitiesKHR(ctx.surface); + auto fmts = ctx.phys.getSurfaceFormatsKHR(ctx.surface); + if (!(caps.supportedUsageFlags & vk::ImageUsageFlagBits::eTransferDst)) { + throw std::runtime_error("swapchain missing TRANSFER_DST"); + } + + vk::SurfaceFormatKHR surfaceFormat = SelectSurfaceFormat(ctx, fmts); + + vk::Extent2D extent{}; + if (caps.currentExtent.width != UINT32_MAX) { + extent = caps.currentExtent; + } else { + extent.width = static_cast(std::clamp(static_cast(desiredExtent.width), + static_cast(caps.minImageExtent.width), + static_cast(caps.maxImageExtent.width))); + extent.height = static_cast(std::clamp(static_cast(desiredExtent.height), + static_cast(caps.minImageExtent.height), + static_cast(caps.maxImageExtent.height))); + } + + if (extent.width == 0 || extent.height == 0) { + ctx.extent = extent; + return; + } + + uint32_t imageCount = 
std::max(caps.minImageCount, 2u); + if (caps.maxImageCount) imageCount = std::min(imageCount, caps.maxImageCount); + + vk::SwapchainCreateInfoKHR createInfo{}; + createInfo.surface = ctx.surface; + createInfo.minImageCount = imageCount; + createInfo.imageFormat = surfaceFormat.format; + createInfo.imageColorSpace = surfaceFormat.colorSpace; + createInfo.imageExtent = extent; + createInfo.imageArrayLayers = 1; + createInfo.imageUsage = vk::ImageUsageFlagBits::eTransferDst; + createInfo.imageSharingMode = vk::SharingMode::eExclusive; + createInfo.preTransform = caps.currentTransform; + createInfo.compositeAlpha = vk::CompositeAlphaFlagBitsKHR::eOpaque; + createInfo.presentMode = vk::PresentModeKHR::eFifo; + createInfo.clipped = VK_TRUE; + createInfo.oldSwapchain = ctx.swapchain; + + vk::SwapchainKHR newSwapchain = ctx.device.createSwapchainKHR(createInfo); + if (ctx.swapchain) { + ctx.device.destroySwapchainKHR(ctx.swapchain); + } + ctx.swapchain = newSwapchain; + ctx.images = ctx.device.getSwapchainImagesKHR(ctx.swapchain); + ctx.extent = extent; + ctx.format = surfaceFormat.format; + + EnsureFramebuffers(ctx); + + if (ctx.imguiContext) { + ImGui_ImplVulkan_SetMinImageCount(imageCount); + } +} } // namespace void ReleaseImGui(WindowContext& ctx) { @@ -240,12 +326,39 @@ bool windowOpen(const WindowContext &ctx) { } void drawBuffer(WindowContext &ctx, vk::Buffer src, uint32_t width, uint32_t height, vk::DeviceSize offset) { + if (!ctx.wnd) return; + glfwPollEvents(); - auto acq = ctx.device.acquireNextImageKHR(ctx.swapchain, UINT64_MAX, ctx.semImage, {}); - if (acq.result == vk::Result::eSuboptimalKHR) {} // continue - if (acq.result == vk::Result::eErrorOutOfDateKHR) return; // ignore; recreate swapchain if you want - uint32_t idx = acq.value; + uint32_t idx = 0; + while (true) { + int fbw = 0; + int fbh = 0; + glfwGetFramebufferSize(ctx.wnd, &fbw, &fbh); + if (fbw <= 0 || fbh <= 0) { + if (ctx.imguiContext && ctx.imguiFrameActive) { + 
ImGui::SetCurrentContext(ctx.imguiContext); + ImGui::Render(); + ctx.imguiFrameActive = false; + } + return; + } + + vk::Extent2D desiredExtent{static_cast(fbw), static_cast(fbh)}; + if (desiredExtent.width != ctx.extent.width || desiredExtent.height != ctx.extent.height) { + RecreateSwapchain(ctx, desiredExtent); + continue; + } + + auto acq = ctx.device.acquireNextImageKHR(ctx.swapchain, UINT64_MAX, ctx.semImage, {}); + if (acq.result == vk::Result::eErrorOutOfDateKHR || acq.result == vk::Result::eSuboptimalKHR) { + RecreateSwapchain(ctx, desiredExtent); + continue; + } + + idx = acq.value; + break; + } ctx.cmd.reset({}); ctx.cmd.begin({vk::CommandBufferUsageFlagBits::eOneTimeSubmit}); @@ -260,32 +373,39 @@ void drawBuffer(WindowContext &ctx, vk::Buffer src, uint32_t width, uint32_t hei vk::PipelineStageFlagBits::eTransfer, {}, nullptr, nullptr, toTransfer); - if (src) { + uint32_t copyWidth = std::min(width, ctx.extent.width); + uint32_t copyHeight = std::min(height, ctx.extent.height); + + bool performedTransfer = false; + + if (!src || copyWidth != ctx.extent.width || copyHeight != ctx.extent.height) { + vk::ClearColorValue clearColor(std::array{0.f, 0.f, 0.f, 1.f}); + std::array ranges{range}; + ctx.cmd.clearColorImage(ctx.images[idx], vk::ImageLayout::eTransferDstOptimal, clearColor, ranges); + performedTransfer = true; + } + + if (src && copyWidth > 0 && copyHeight > 0) { vk::BufferImageCopy copy{}; copy.bufferOffset = offset; copy.imageSubresource = {vk::ImageAspectFlagBits::eColor, 0, 0, 1}; - copy.imageExtent = vk::Extent3D{ width, height, 1 }; + copy.imageExtent = vk::Extent3D{copyWidth, copyHeight, 1}; ctx.cmd.copyBufferToImage(src, ctx.images[idx], vk::ImageLayout::eTransferDstOptimal, 1, ©); + performedTransfer = true; } - vk::ImageMemoryBarrier toColor(src ? vk::AccessFlagBits::eTransferWrite : vk::AccessFlagBits{}, + vk::ImageMemoryBarrier toColor(performedTransfer ? 
vk::AccessFlagBits::eTransferWrite : vk::AccessFlags{}, vk::AccessFlagBits::eColorAttachmentWrite, - src ? vk::ImageLayout::eTransferDstOptimal : vk::ImageLayout::eUndefined, + performedTransfer ? vk::ImageLayout::eTransferDstOptimal : vk::ImageLayout::eUndefined, vk::ImageLayout::eColorAttachmentOptimal, VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, ctx.images[idx], range); - ctx.cmd.pipelineBarrier(src ? vk::PipelineStageFlagBits::eTransfer : vk::PipelineStageFlagBits::eTopOfPipe, + ctx.cmd.pipelineBarrier(performedTransfer ? vk::PipelineStageFlagBits::eTransfer : vk::PipelineStageFlagBits::eTopOfPipe, vk::PipelineStageFlagBits::eColorAttachmentOutput, {}, nullptr, nullptr, toColor); EnsureFramebuffers(ctx); - if (!src) { - vk::ClearColorValue clearColor(std::array{0.f, 0.f, 0.f, 1.f}); - std::array ranges{range}; - ctx.cmd.clearColorImage(ctx.images[idx], vk::ImageLayout::eColorAttachmentOptimal, clearColor, ranges); - } - vk::RenderPassBeginInfo rpBegin{}; rpBegin.renderPass = ctx.renderPass; rpBegin.framebuffer = ctx.framebuffers[idx]; diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp index ce5e20d8..1f205c6f 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.cpp +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -260,8 +260,15 @@ void PyWindow::present() { ::drawBuffer(window_, vk::Buffer{}, window_.extent.width, window_.extent.height, 0); } -py::tuple PyWindow::size() const { +py::tuple PyWindow::size() { ensureValid(); + int fbw = 0; + int fbh = 0; + glfwGetFramebufferSize(window_.wnd, &fbw, &fbh); + if (fbw > 0 && fbh > 0) { + window_.extent.width = static_cast(fbw); + window_.extent.height = static_cast(fbh); + } return py::make_tuple(window_.extent.width, window_.extent.height); } diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h index 7c0441b2..03d5a905 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.h +++ 
b/TensorFrost/src/Definitions/VulkanInterface.h @@ -100,7 +100,7 @@ class PyWindow { bool isOpen() const; void drawBuffer(const PyBuffer& buffer, uint32_t width, uint32_t height, size_t offset); void present(); - pybind11::tuple size() const; + pybind11::tuple size(); int format() const; void close(); From a988db6aa22cddd1eeebefe8bee1ae7111ba1f6e Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Thu, 6 Nov 2025 21:43:11 +0100 Subject: [PATCH 27/44] Move old examples into legacy examples --- examples/{ => legacy}/Algorithms/bitonic.ipynb | 0 .../{ => legacy}/Algorithms/custom_operation.py | 0 examples/{ => legacy}/Algorithms/fft.ipynb | 0 examples/{ => legacy}/Algorithms/fft_group.ipynb | 0 examples/{ => legacy}/Algorithms/indexing.py | 0 .../{ => legacy}/Algorithms/indexing_test.ipynb | 0 examples/{ => legacy}/Algorithms/kernels.ipynb | 0 examples/{ => legacy}/Algorithms/matrix_mul.ipynb | 0 examples/{ => legacy}/Algorithms/qr.ipynb | 0 .../Algorithms/random_number_generation.ipynb | 0 .../Algorithms/random_permutation.ipynb | 0 .../{ => legacy}/Algorithms/reshape_reduction.ipynb | 0 examples/{ => legacy}/Algorithms/scan.ipynb | 0 examples/{ => legacy}/Algorithms/scatter.py | 0 examples/{ => legacy}/Algorithms/sorting.ipynb | 0 examples/{ => legacy}/Demos/buddhabrot.gif | Bin examples/{ => legacy}/Demos/fluid_sim.gif | Bin examples/{ => legacy}/Demos/n_body.gif | Bin examples/{ => legacy}/Demos/nca.gif | Bin examples/{ => legacy}/Demos/neural_embed.gif | Bin examples/{ => legacy}/Demos/path_tracer.gif | Bin examples/{ => legacy}/Demos/sin_gordon.gif | Bin examples/{ => legacy}/GUI/buddhabrot.py | 0 examples/{ => legacy}/GUI/garden_smol.hdr | Bin examples/{ => legacy}/GUI/image_matcher.py | 0 .../{ => legacy}/GUI/interactive_path_tracer.py | 0 examples/{ => legacy}/ML/MNIST/MNIST.ipynb | 0 examples/{ => legacy}/ML/MNIST/loadMNIST.py | 0 examples/{ => legacy}/ML/MNIST/module.py | 0 examples/{ => 
legacy}/ML/MNIST/pytorch.py | 0 examples/{ => legacy}/ML/NCA/bugcat.png | Bin examples/{ => legacy}/ML/NCA/catthink.png | Bin examples/{ => legacy}/ML/NCA/inference.py | 0 examples/{ => legacy}/ML/NCA/nca.py | 0 examples/{ => legacy}/ML/NCA/shadertoy.py | 0 examples/{ => legacy}/ML/NCA/train.py | 0 examples/{ => legacy}/ML/VMC/atom.py | 0 examples/{ => legacy}/ML/VMC/camera.py | 0 examples/{ => legacy}/ML/VMC/logdet.py | 0 examples/{ => legacy}/ML/VMC/logdet_test.py | 0 examples/{ => legacy}/ML/VMC/molecules.py | 0 examples/{ => legacy}/ML/VMC/utils.py | 0 examples/{ => legacy}/ML/VMC/vec3.py | 0 examples/{ => legacy}/ML/VMC/visualizer.py | 0 examples/{ => legacy}/ML/VMC/visualizer_test.py | 0 examples/{ => legacy}/ML/VMC/vmc.py | 0 examples/{ => legacy}/Rendering/blur.ipynb | 0 examples/{ => legacy}/Rendering/buddhabrot.ipynb | 0 examples/{ => legacy}/Rendering/convolution.py | 0 examples/{ => legacy}/Rendering/fft2d.ipynb | 0 examples/{ => legacy}/Rendering/fft3d.py | 0 examples/{ => legacy}/Rendering/gaussian_grid.ipynb | 0 examples/{ => legacy}/Rendering/mandelbrot.ipynb | 0 examples/{ => legacy}/Rendering/neural_embed.ipynb | 0 examples/{ => legacy}/Rendering/neural_embed2.ipynb | 0 examples/{ => legacy}/Rendering/ray_marcher.ipynb | 0 examples/{ => legacy}/Rendering/sphere_tracer.ipynb | 0 examples/{ => legacy}/Rendering/test.png | Bin .../{ => legacy}/Simulation/fluid_simulation.ipynb | 0 .../{ => legacy}/Simulation/n-body-benchmark.py | 0 examples/{ => legacy}/Simulation/n-body.ipynb | 0 examples/{ => legacy}/Simulation/poission.py | 0 .../{ => legacy}/Simulation/wave_simulation.ipynb | 0 63 files changed, 0 insertions(+), 0 deletions(-) rename examples/{ => legacy}/Algorithms/bitonic.ipynb (100%) rename examples/{ => legacy}/Algorithms/custom_operation.py (100%) rename examples/{ => legacy}/Algorithms/fft.ipynb (100%) rename examples/{ => legacy}/Algorithms/fft_group.ipynb (100%) rename examples/{ => legacy}/Algorithms/indexing.py (100%) rename examples/{ 
=> legacy}/Algorithms/indexing_test.ipynb (100%) rename examples/{ => legacy}/Algorithms/kernels.ipynb (100%) rename examples/{ => legacy}/Algorithms/matrix_mul.ipynb (100%) rename examples/{ => legacy}/Algorithms/qr.ipynb (100%) rename examples/{ => legacy}/Algorithms/random_number_generation.ipynb (100%) rename examples/{ => legacy}/Algorithms/random_permutation.ipynb (100%) rename examples/{ => legacy}/Algorithms/reshape_reduction.ipynb (100%) rename examples/{ => legacy}/Algorithms/scan.ipynb (100%) rename examples/{ => legacy}/Algorithms/scatter.py (100%) rename examples/{ => legacy}/Algorithms/sorting.ipynb (100%) rename examples/{ => legacy}/Demos/buddhabrot.gif (100%) rename examples/{ => legacy}/Demos/fluid_sim.gif (100%) rename examples/{ => legacy}/Demos/n_body.gif (100%) rename examples/{ => legacy}/Demos/nca.gif (100%) rename examples/{ => legacy}/Demos/neural_embed.gif (100%) rename examples/{ => legacy}/Demos/path_tracer.gif (100%) rename examples/{ => legacy}/Demos/sin_gordon.gif (100%) rename examples/{ => legacy}/GUI/buddhabrot.py (100%) rename examples/{ => legacy}/GUI/garden_smol.hdr (100%) rename examples/{ => legacy}/GUI/image_matcher.py (100%) rename examples/{ => legacy}/GUI/interactive_path_tracer.py (100%) rename examples/{ => legacy}/ML/MNIST/MNIST.ipynb (100%) rename examples/{ => legacy}/ML/MNIST/loadMNIST.py (100%) rename examples/{ => legacy}/ML/MNIST/module.py (100%) rename examples/{ => legacy}/ML/MNIST/pytorch.py (100%) rename examples/{ => legacy}/ML/NCA/bugcat.png (100%) rename examples/{ => legacy}/ML/NCA/catthink.png (100%) rename examples/{ => legacy}/ML/NCA/inference.py (100%) rename examples/{ => legacy}/ML/NCA/nca.py (100%) rename examples/{ => legacy}/ML/NCA/shadertoy.py (100%) rename examples/{ => legacy}/ML/NCA/train.py (100%) rename examples/{ => legacy}/ML/VMC/atom.py (100%) rename examples/{ => legacy}/ML/VMC/camera.py (100%) rename examples/{ => legacy}/ML/VMC/logdet.py (100%) rename examples/{ => 
legacy}/ML/VMC/logdet_test.py (100%) rename examples/{ => legacy}/ML/VMC/molecules.py (100%) rename examples/{ => legacy}/ML/VMC/utils.py (100%) rename examples/{ => legacy}/ML/VMC/vec3.py (100%) rename examples/{ => legacy}/ML/VMC/visualizer.py (100%) rename examples/{ => legacy}/ML/VMC/visualizer_test.py (100%) rename examples/{ => legacy}/ML/VMC/vmc.py (100%) rename examples/{ => legacy}/Rendering/blur.ipynb (100%) rename examples/{ => legacy}/Rendering/buddhabrot.ipynb (100%) rename examples/{ => legacy}/Rendering/convolution.py (100%) rename examples/{ => legacy}/Rendering/fft2d.ipynb (100%) rename examples/{ => legacy}/Rendering/fft3d.py (100%) rename examples/{ => legacy}/Rendering/gaussian_grid.ipynb (100%) rename examples/{ => legacy}/Rendering/mandelbrot.ipynb (100%) rename examples/{ => legacy}/Rendering/neural_embed.ipynb (100%) rename examples/{ => legacy}/Rendering/neural_embed2.ipynb (100%) rename examples/{ => legacy}/Rendering/ray_marcher.ipynb (100%) rename examples/{ => legacy}/Rendering/sphere_tracer.ipynb (100%) rename examples/{ => legacy}/Rendering/test.png (100%) rename examples/{ => legacy}/Simulation/fluid_simulation.ipynb (100%) rename examples/{ => legacy}/Simulation/n-body-benchmark.py (100%) rename examples/{ => legacy}/Simulation/n-body.ipynb (100%) rename examples/{ => legacy}/Simulation/poission.py (100%) rename examples/{ => legacy}/Simulation/wave_simulation.ipynb (100%) diff --git a/examples/Algorithms/bitonic.ipynb b/examples/legacy/Algorithms/bitonic.ipynb similarity index 100% rename from examples/Algorithms/bitonic.ipynb rename to examples/legacy/Algorithms/bitonic.ipynb diff --git a/examples/Algorithms/custom_operation.py b/examples/legacy/Algorithms/custom_operation.py similarity index 100% rename from examples/Algorithms/custom_operation.py rename to examples/legacy/Algorithms/custom_operation.py diff --git a/examples/Algorithms/fft.ipynb b/examples/legacy/Algorithms/fft.ipynb similarity index 100% rename from 
examples/Algorithms/fft.ipynb rename to examples/legacy/Algorithms/fft.ipynb diff --git a/examples/Algorithms/fft_group.ipynb b/examples/legacy/Algorithms/fft_group.ipynb similarity index 100% rename from examples/Algorithms/fft_group.ipynb rename to examples/legacy/Algorithms/fft_group.ipynb diff --git a/examples/Algorithms/indexing.py b/examples/legacy/Algorithms/indexing.py similarity index 100% rename from examples/Algorithms/indexing.py rename to examples/legacy/Algorithms/indexing.py diff --git a/examples/Algorithms/indexing_test.ipynb b/examples/legacy/Algorithms/indexing_test.ipynb similarity index 100% rename from examples/Algorithms/indexing_test.ipynb rename to examples/legacy/Algorithms/indexing_test.ipynb diff --git a/examples/Algorithms/kernels.ipynb b/examples/legacy/Algorithms/kernels.ipynb similarity index 100% rename from examples/Algorithms/kernels.ipynb rename to examples/legacy/Algorithms/kernels.ipynb diff --git a/examples/Algorithms/matrix_mul.ipynb b/examples/legacy/Algorithms/matrix_mul.ipynb similarity index 100% rename from examples/Algorithms/matrix_mul.ipynb rename to examples/legacy/Algorithms/matrix_mul.ipynb diff --git a/examples/Algorithms/qr.ipynb b/examples/legacy/Algorithms/qr.ipynb similarity index 100% rename from examples/Algorithms/qr.ipynb rename to examples/legacy/Algorithms/qr.ipynb diff --git a/examples/Algorithms/random_number_generation.ipynb b/examples/legacy/Algorithms/random_number_generation.ipynb similarity index 100% rename from examples/Algorithms/random_number_generation.ipynb rename to examples/legacy/Algorithms/random_number_generation.ipynb diff --git a/examples/Algorithms/random_permutation.ipynb b/examples/legacy/Algorithms/random_permutation.ipynb similarity index 100% rename from examples/Algorithms/random_permutation.ipynb rename to examples/legacy/Algorithms/random_permutation.ipynb diff --git a/examples/Algorithms/reshape_reduction.ipynb b/examples/legacy/Algorithms/reshape_reduction.ipynb similarity 
index 100% rename from examples/Algorithms/reshape_reduction.ipynb rename to examples/legacy/Algorithms/reshape_reduction.ipynb diff --git a/examples/Algorithms/scan.ipynb b/examples/legacy/Algorithms/scan.ipynb similarity index 100% rename from examples/Algorithms/scan.ipynb rename to examples/legacy/Algorithms/scan.ipynb diff --git a/examples/Algorithms/scatter.py b/examples/legacy/Algorithms/scatter.py similarity index 100% rename from examples/Algorithms/scatter.py rename to examples/legacy/Algorithms/scatter.py diff --git a/examples/Algorithms/sorting.ipynb b/examples/legacy/Algorithms/sorting.ipynb similarity index 100% rename from examples/Algorithms/sorting.ipynb rename to examples/legacy/Algorithms/sorting.ipynb diff --git a/examples/Demos/buddhabrot.gif b/examples/legacy/Demos/buddhabrot.gif similarity index 100% rename from examples/Demos/buddhabrot.gif rename to examples/legacy/Demos/buddhabrot.gif diff --git a/examples/Demos/fluid_sim.gif b/examples/legacy/Demos/fluid_sim.gif similarity index 100% rename from examples/Demos/fluid_sim.gif rename to examples/legacy/Demos/fluid_sim.gif diff --git a/examples/Demos/n_body.gif b/examples/legacy/Demos/n_body.gif similarity index 100% rename from examples/Demos/n_body.gif rename to examples/legacy/Demos/n_body.gif diff --git a/examples/Demos/nca.gif b/examples/legacy/Demos/nca.gif similarity index 100% rename from examples/Demos/nca.gif rename to examples/legacy/Demos/nca.gif diff --git a/examples/Demos/neural_embed.gif b/examples/legacy/Demos/neural_embed.gif similarity index 100% rename from examples/Demos/neural_embed.gif rename to examples/legacy/Demos/neural_embed.gif diff --git a/examples/Demos/path_tracer.gif b/examples/legacy/Demos/path_tracer.gif similarity index 100% rename from examples/Demos/path_tracer.gif rename to examples/legacy/Demos/path_tracer.gif diff --git a/examples/Demos/sin_gordon.gif b/examples/legacy/Demos/sin_gordon.gif similarity index 100% rename from examples/Demos/sin_gordon.gif 
rename to examples/legacy/Demos/sin_gordon.gif diff --git a/examples/GUI/buddhabrot.py b/examples/legacy/GUI/buddhabrot.py similarity index 100% rename from examples/GUI/buddhabrot.py rename to examples/legacy/GUI/buddhabrot.py diff --git a/examples/GUI/garden_smol.hdr b/examples/legacy/GUI/garden_smol.hdr similarity index 100% rename from examples/GUI/garden_smol.hdr rename to examples/legacy/GUI/garden_smol.hdr diff --git a/examples/GUI/image_matcher.py b/examples/legacy/GUI/image_matcher.py similarity index 100% rename from examples/GUI/image_matcher.py rename to examples/legacy/GUI/image_matcher.py diff --git a/examples/GUI/interactive_path_tracer.py b/examples/legacy/GUI/interactive_path_tracer.py similarity index 100% rename from examples/GUI/interactive_path_tracer.py rename to examples/legacy/GUI/interactive_path_tracer.py diff --git a/examples/ML/MNIST/MNIST.ipynb b/examples/legacy/ML/MNIST/MNIST.ipynb similarity index 100% rename from examples/ML/MNIST/MNIST.ipynb rename to examples/legacy/ML/MNIST/MNIST.ipynb diff --git a/examples/ML/MNIST/loadMNIST.py b/examples/legacy/ML/MNIST/loadMNIST.py similarity index 100% rename from examples/ML/MNIST/loadMNIST.py rename to examples/legacy/ML/MNIST/loadMNIST.py diff --git a/examples/ML/MNIST/module.py b/examples/legacy/ML/MNIST/module.py similarity index 100% rename from examples/ML/MNIST/module.py rename to examples/legacy/ML/MNIST/module.py diff --git a/examples/ML/MNIST/pytorch.py b/examples/legacy/ML/MNIST/pytorch.py similarity index 100% rename from examples/ML/MNIST/pytorch.py rename to examples/legacy/ML/MNIST/pytorch.py diff --git a/examples/ML/NCA/bugcat.png b/examples/legacy/ML/NCA/bugcat.png similarity index 100% rename from examples/ML/NCA/bugcat.png rename to examples/legacy/ML/NCA/bugcat.png diff --git a/examples/ML/NCA/catthink.png b/examples/legacy/ML/NCA/catthink.png similarity index 100% rename from examples/ML/NCA/catthink.png rename to examples/legacy/ML/NCA/catthink.png diff --git 
a/examples/ML/NCA/inference.py b/examples/legacy/ML/NCA/inference.py similarity index 100% rename from examples/ML/NCA/inference.py rename to examples/legacy/ML/NCA/inference.py diff --git a/examples/ML/NCA/nca.py b/examples/legacy/ML/NCA/nca.py similarity index 100% rename from examples/ML/NCA/nca.py rename to examples/legacy/ML/NCA/nca.py diff --git a/examples/ML/NCA/shadertoy.py b/examples/legacy/ML/NCA/shadertoy.py similarity index 100% rename from examples/ML/NCA/shadertoy.py rename to examples/legacy/ML/NCA/shadertoy.py diff --git a/examples/ML/NCA/train.py b/examples/legacy/ML/NCA/train.py similarity index 100% rename from examples/ML/NCA/train.py rename to examples/legacy/ML/NCA/train.py diff --git a/examples/ML/VMC/atom.py b/examples/legacy/ML/VMC/atom.py similarity index 100% rename from examples/ML/VMC/atom.py rename to examples/legacy/ML/VMC/atom.py diff --git a/examples/ML/VMC/camera.py b/examples/legacy/ML/VMC/camera.py similarity index 100% rename from examples/ML/VMC/camera.py rename to examples/legacy/ML/VMC/camera.py diff --git a/examples/ML/VMC/logdet.py b/examples/legacy/ML/VMC/logdet.py similarity index 100% rename from examples/ML/VMC/logdet.py rename to examples/legacy/ML/VMC/logdet.py diff --git a/examples/ML/VMC/logdet_test.py b/examples/legacy/ML/VMC/logdet_test.py similarity index 100% rename from examples/ML/VMC/logdet_test.py rename to examples/legacy/ML/VMC/logdet_test.py diff --git a/examples/ML/VMC/molecules.py b/examples/legacy/ML/VMC/molecules.py similarity index 100% rename from examples/ML/VMC/molecules.py rename to examples/legacy/ML/VMC/molecules.py diff --git a/examples/ML/VMC/utils.py b/examples/legacy/ML/VMC/utils.py similarity index 100% rename from examples/ML/VMC/utils.py rename to examples/legacy/ML/VMC/utils.py diff --git a/examples/ML/VMC/vec3.py b/examples/legacy/ML/VMC/vec3.py similarity index 100% rename from examples/ML/VMC/vec3.py rename to examples/legacy/ML/VMC/vec3.py diff --git a/examples/ML/VMC/visualizer.py 
b/examples/legacy/ML/VMC/visualizer.py similarity index 100% rename from examples/ML/VMC/visualizer.py rename to examples/legacy/ML/VMC/visualizer.py diff --git a/examples/ML/VMC/visualizer_test.py b/examples/legacy/ML/VMC/visualizer_test.py similarity index 100% rename from examples/ML/VMC/visualizer_test.py rename to examples/legacy/ML/VMC/visualizer_test.py diff --git a/examples/ML/VMC/vmc.py b/examples/legacy/ML/VMC/vmc.py similarity index 100% rename from examples/ML/VMC/vmc.py rename to examples/legacy/ML/VMC/vmc.py diff --git a/examples/Rendering/blur.ipynb b/examples/legacy/Rendering/blur.ipynb similarity index 100% rename from examples/Rendering/blur.ipynb rename to examples/legacy/Rendering/blur.ipynb diff --git a/examples/Rendering/buddhabrot.ipynb b/examples/legacy/Rendering/buddhabrot.ipynb similarity index 100% rename from examples/Rendering/buddhabrot.ipynb rename to examples/legacy/Rendering/buddhabrot.ipynb diff --git a/examples/Rendering/convolution.py b/examples/legacy/Rendering/convolution.py similarity index 100% rename from examples/Rendering/convolution.py rename to examples/legacy/Rendering/convolution.py diff --git a/examples/Rendering/fft2d.ipynb b/examples/legacy/Rendering/fft2d.ipynb similarity index 100% rename from examples/Rendering/fft2d.ipynb rename to examples/legacy/Rendering/fft2d.ipynb diff --git a/examples/Rendering/fft3d.py b/examples/legacy/Rendering/fft3d.py similarity index 100% rename from examples/Rendering/fft3d.py rename to examples/legacy/Rendering/fft3d.py diff --git a/examples/Rendering/gaussian_grid.ipynb b/examples/legacy/Rendering/gaussian_grid.ipynb similarity index 100% rename from examples/Rendering/gaussian_grid.ipynb rename to examples/legacy/Rendering/gaussian_grid.ipynb diff --git a/examples/Rendering/mandelbrot.ipynb b/examples/legacy/Rendering/mandelbrot.ipynb similarity index 100% rename from examples/Rendering/mandelbrot.ipynb rename to examples/legacy/Rendering/mandelbrot.ipynb diff --git 
a/examples/Rendering/neural_embed.ipynb b/examples/legacy/Rendering/neural_embed.ipynb similarity index 100% rename from examples/Rendering/neural_embed.ipynb rename to examples/legacy/Rendering/neural_embed.ipynb diff --git a/examples/Rendering/neural_embed2.ipynb b/examples/legacy/Rendering/neural_embed2.ipynb similarity index 100% rename from examples/Rendering/neural_embed2.ipynb rename to examples/legacy/Rendering/neural_embed2.ipynb diff --git a/examples/Rendering/ray_marcher.ipynb b/examples/legacy/Rendering/ray_marcher.ipynb similarity index 100% rename from examples/Rendering/ray_marcher.ipynb rename to examples/legacy/Rendering/ray_marcher.ipynb diff --git a/examples/Rendering/sphere_tracer.ipynb b/examples/legacy/Rendering/sphere_tracer.ipynb similarity index 100% rename from examples/Rendering/sphere_tracer.ipynb rename to examples/legacy/Rendering/sphere_tracer.ipynb diff --git a/examples/Rendering/test.png b/examples/legacy/Rendering/test.png similarity index 100% rename from examples/Rendering/test.png rename to examples/legacy/Rendering/test.png diff --git a/examples/Simulation/fluid_simulation.ipynb b/examples/legacy/Simulation/fluid_simulation.ipynb similarity index 100% rename from examples/Simulation/fluid_simulation.ipynb rename to examples/legacy/Simulation/fluid_simulation.ipynb diff --git a/examples/Simulation/n-body-benchmark.py b/examples/legacy/Simulation/n-body-benchmark.py similarity index 100% rename from examples/Simulation/n-body-benchmark.py rename to examples/legacy/Simulation/n-body-benchmark.py diff --git a/examples/Simulation/n-body.ipynb b/examples/legacy/Simulation/n-body.ipynb similarity index 100% rename from examples/Simulation/n-body.ipynb rename to examples/legacy/Simulation/n-body.ipynb diff --git a/examples/Simulation/poission.py b/examples/legacy/Simulation/poission.py similarity index 100% rename from examples/Simulation/poission.py rename to examples/legacy/Simulation/poission.py diff --git 
a/examples/Simulation/wave_simulation.ipynb b/examples/legacy/Simulation/wave_simulation.ipynb similarity index 100% rename from examples/Simulation/wave_simulation.ipynb rename to examples/legacy/Simulation/wave_simulation.ipynb From 1aa3a639698843f87c72fbb80c4933f492d692d3 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:17:27 +0100 Subject: [PATCH 28/44] Add additional ImGui features --- .idea/copilot.data.migration.agent.xml | 6 + .../Compiler/src/OperationRegistry.cpp | 31 +- .../src/Definitions/VulkanBindings.cpp | 100 ++++++- .../src/Definitions/VulkanInterface.cpp | 266 +++++++++++++++++ TensorFrost/src/Definitions/VulkanInterface.h | 64 +++++ examples/imgui_showcase.py | 271 ++++++++++++++++++ tests/imgui_test.py | 68 +++++ 7 files changed, 803 insertions(+), 3 deletions(-) create mode 100644 .idea/copilot.data.migration.agent.xml create mode 100644 examples/imgui_showcase.py diff --git a/.idea/copilot.data.migration.agent.xml b/.idea/copilot.data.migration.agent.xml new file mode 100644 index 00000000..89c07510 --- /dev/null +++ b/.idea/copilot.data.migration.agent.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/TensorFrost/Compiler/src/OperationRegistry.cpp b/TensorFrost/Compiler/src/OperationRegistry.cpp index e3076618..5e00db8c 100644 --- a/TensorFrost/Compiler/src/OperationRegistry.cpp +++ b/TensorFrost/Compiler/src/OperationRegistry.cpp @@ -1,3 +1,6 @@ +#include +#include + #include "Compiler/Operation.h" using namespace std; @@ -118,7 +121,7 @@ bool OpSpec::IsValid(const std::vector& inputs, TFDataFormat outpu #define BIN_OP_FOLD(op) \ make_fold2([](auto a, auto b) { \ - if constexpr (std::is_same_v || std::is_same_v) { \ + if constexpr (std::is_same_v, bool> || std::is_same_v, bool>) { \ return static_cast(a) op static_cast(b); \ } else { \ return a op b; \ @@ -131,8 +134,32 @@ make_fold2([](auto a, auto b) { \ #define UN_FUNC_FOLD(op) \ make_fold1([](auto 
a) { return op(a); }) +namespace { + +template +auto promote_for_compare(T&& value) { + using Decayed = std::decay_t; + if constexpr (std::is_same_v) { + return static_cast(value); + } else { + return static_cast(value); + } +} + +template +FoldFn make_compare_fold_impl(Compare cmp) { + return make_fold2([cmp = std::move(cmp)](auto a, auto b) -> Attribute { + auto lhs = promote_for_compare(std::forward(a)); + auto rhs = promote_for_compare(std::forward(b)); + using Common = std::common_type_t; + return (cmp)(static_cast(lhs), static_cast(rhs)); + }); +} + +} // namespace + #define BIN_FUNC_FOLD(op) \ - make_fold2([](auto a, auto b) { return op(a, b); }) + make_compare_fold_impl(op) #define TERN_FUNC_FOLD(op) \ make_fold3([](auto a, auto b, auto c) { return op(a, b, c); }) diff --git a/TensorFrost/src/Definitions/VulkanBindings.cpp b/TensorFrost/src/Definitions/VulkanBindings.cpp index 0a642a05..6e1550ab 100644 --- a/TensorFrost/src/Definitions/VulkanBindings.cpp +++ b/TensorFrost/src/Definitions/VulkanBindings.cpp @@ -118,7 +118,105 @@ void VulkanDefinitions(py::module_& m) { "Scale all ImGui sizes by a factor.") .def("imgui_add_background_text", &PyWindow::imguiAddBackgroundText, py::arg("text"), py::arg("pos"), py::arg("color"), - "Draw text in the background draw list."); + "Draw text in the background draw list.") + .def("imgui_same_line", &PyWindow::imguiSameLine, + py::arg("offset_from_start_x") = 0.0f, py::arg("spacing") = -1.0f, + "Place the next item on the same horizontal line.") + .def("imgui_separator", &PyWindow::imguiSeparator, + "Insert a separator line between items.") + .def("imgui_spacing", &PyWindow::imguiSpacing, + "Insert vertical spacing between items.") + .def("imgui_indent", &PyWindow::imguiIndent, + py::arg("indent_w") = 0.0f, + "Increase the current horizontal indent.") + .def("imgui_unindent", &PyWindow::imguiUnindent, + py::arg("indent_w") = 0.0f, + "Decrease the current horizontal indent.") + .def("imgui_begin_child", 
&PyWindow::imguiBeginChild, + py::arg("id"), py::arg("size") = py::none(), py::arg("border") = false, py::arg("flags") = 0, + "Begin a child region and return True if visible. Always pair with :meth:`imgui_end_child` even when the return value is False.") + .def("imgui_end_child", &PyWindow::imguiEndChild, + "End the current child region.") + .def("imgui_text_wrapped", &PyWindow::imguiTextWrapped, + py::arg("text"), + "Render wrapped text within the current column width.") + .def("imgui_text_colored", &PyWindow::imguiTextColored, + py::arg("color"), py::arg("text"), + "Render text with the given RGBA color.") + .def("imgui_bullet_text", &PyWindow::imguiBulletText, + py::arg("text"), + "Render text preceded by a bullet.") + .def("imgui_input_text", &PyWindow::imguiInputText, + py::arg("label"), py::arg("value"), py::arg("buffer_length") = 0, py::arg("flags") = 0, + "Input text returning (modified, value).") + .def("imgui_input_int", &PyWindow::imguiInputInt, + py::arg("label"), py::arg("value"), py::arg("step") = 1, py::arg("step_fast") = 100, py::arg("flags") = 0, + "Integer input returning the updated value.") + .def("imgui_input_float", &PyWindow::imguiInputFloat, + py::arg("label"), py::arg("value"), py::arg("step") = 0.0f, py::arg("step_fast") = 0.0f, + py::arg("format") = "%.3f", py::arg("flags") = 0, + "Float input returning the updated value.") + .def("imgui_color_edit3", &PyWindow::imguiColorEdit3, + py::arg("label"), py::arg("color"), py::arg("flags") = 0, + "Color editor returning (modified, rgb tuple).") + .def("imgui_color_edit4", &PyWindow::imguiColorEdit4, + py::arg("label"), py::arg("color"), py::arg("flags") = 0, + "Color editor returning (modified, rgba tuple).") + .def("imgui_begin_main_menu_bar", &PyWindow::imguiBeginMainMenuBar, + "Begin the global main menu bar, returning True if it is visible.") + .def("imgui_end_main_menu_bar", &PyWindow::imguiEndMainMenuBar, + "End the global main menu bar.") + .def("imgui_begin_menu_bar", 
&PyWindow::imguiBeginMenuBar, + "Begin a menu bar on the current window, returning True if visible.") + .def("imgui_end_menu_bar", &PyWindow::imguiEndMenuBar, + "End the current menu bar.") + .def("imgui_begin_menu", &PyWindow::imguiBeginMenu, + py::arg("label"), py::arg("enabled") = true, + "Begin a menu entry and return True if it is open.") + .def("imgui_end_menu", &PyWindow::imguiEndMenu, + "End the current menu entry.") + .def("imgui_menu_item", &PyWindow::imguiMenuItem, + py::arg("label"), py::arg("shortcut") = py::none(), py::arg("selected") = false, py::arg("enabled") = true, + "Create a menu item and return True when activated.") + .def("imgui_open_popup", &PyWindow::imguiOpenPopup, + py::arg("id"), py::arg("popup_flags") = 0, + "Open a popup window by identifier.") + .def("imgui_begin_popup", &PyWindow::imguiBeginPopup, + py::arg("id"), py::arg("flags") = 0, + "Begin a popup window, returning True if it is open.") + .def("imgui_begin_popup_modal", &PyWindow::imguiBeginPopupModal, + py::arg("name"), py::arg("open") = py::none(), py::arg("flags") = 0, + "Begin a modal popup, returning (visible, open_flag_or_None).") + .def("imgui_end_popup", &PyWindow::imguiEndPopup, + "End the current popup window.") + .def("imgui_close_current_popup", &PyWindow::imguiCloseCurrentPopup, + "Close the current popup window.") + .def("imgui_push_style_color", &PyWindow::imguiPushStyleColor, + py::arg("index"), py::arg("color"), + "Push a style color onto the stack.") + .def("imgui_pop_style_color", &PyWindow::imguiPopStyleColor, + py::arg("count") = 1, + "Pop style colors from the stack.") + .def("imgui_push_style_var_float", &PyWindow::imguiPushStyleVarFloat, + py::arg("index"), py::arg("value"), + "Push a float style variable onto the stack.") + .def("imgui_push_style_var_vec2", &PyWindow::imguiPushStyleVarVec2, + py::arg("index"), py::arg("value"), + "Push a 2D vector style variable onto the stack.") + .def("imgui_pop_style_var", &PyWindow::imguiPopStyleVar, + 
py::arg("count") = 1, + "Pop style variables from the stack.") + .def("imgui_get_font_global_scale", &PyWindow::imguiGetFontGlobalScale, + "Get the global font scale factor.") + .def("imgui_set_font_global_scale", &PyWindow::imguiSetFontGlobalScale, + py::arg("scale"), + "Set the global font scale factor.") + .def("imgui_get_style_color_vec4", &PyWindow::imguiGetStyleColorVec4, + py::arg("index"), + "Get a style color as an RGBA tuple.") + .def("imgui_set_style_color_vec4", &PyWindow::imguiSetStyleColorVec4, + py::arg("index"), py::arg("color"), + "Set a style color from an RGBA tuple."); m.def("createWindow", [](int width, int height, const std::string& title) { diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp index 1f205c6f..094346df 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.cpp +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include @@ -366,6 +367,250 @@ void PyWindow::imguiAddBackgroundText(const std::string& text, text.c_str()); } +void PyWindow::imguiSameLine(float offsetFromStartX, float spacing) { + bindImGui(); + ImGui::SameLine(offsetFromStartX, spacing); +} + +void PyWindow::imguiSeparator() { + bindImGui(); + ImGui::Separator(); +} + +void PyWindow::imguiSpacing() { + bindImGui(); + ImGui::Spacing(); +} + +void PyWindow::imguiIndent(float indentW) { + bindImGui(); + ImGui::Indent(indentW); +} + +void PyWindow::imguiUnindent(float indentW) { + bindImGui(); + ImGui::Unindent(indentW); +} + +bool PyWindow::imguiBeginChild(const std::string& id, + const py::object& size, + bool border, + int flags) { + bindImGui(); + ImVec2 vecSize = objectToVec2(size, "size"); + return ImGui::BeginChild(id.c_str(), vecSize, border, flags); +} + +void PyWindow::imguiEndChild() { + bindImGui(); + ImGui::EndChild(); +} + +void PyWindow::imguiTextWrapped(const std::string& text) { + bindImGui(); + ImGui::TextWrapped("%s", text.c_str()); +} + 
+void PyWindow::imguiTextColored(py::tuple color, + const std::string& text) { + bindImGui(); + ImVec4 col = tupleToVec4(color, "color"); + ImGui::TextColored(col, "%s", text.c_str()); +} + +void PyWindow::imguiBulletText(const std::string& text) { + bindImGui(); + ImGui::BulletText("%s", text.c_str()); +} + +std::tuple PyWindow::imguiInputText(const std::string& label, + const std::string& value, + size_t bufferLength, + int flags) { + bindImGui(); + size_t minimum = value.size() + 1; + size_t capacity = bufferLength ? std::max(bufferLength, minimum) : std::max(minimum, value.size() + 256); + if (capacity == 0) capacity = 1; + std::vector buffer(capacity, '\0'); + std::copy(value.begin(), value.end(), buffer.begin()); + bool edited = ImGui::InputText(label.c_str(), buffer.data(), buffer.size(), flags); + std::string result(buffer.data()); + return std::make_tuple(edited, std::move(result)); +} + +int PyWindow::imguiInputInt(const std::string& label, int value, int step, int stepFast, int flags) { + bindImGui(); + int v = value; + ImGui::InputInt(label.c_str(), &v, step, stepFast, flags); + return v; +} + +float PyWindow::imguiInputFloat(const std::string& label, + float value, + float step, + float stepFast, + const std::string& format, + int flags) { + bindImGui(); + float v = value; + ImGui::InputFloat(label.c_str(), &v, step, stepFast, format.c_str(), flags); + return v; +} + +std::tuple PyWindow::imguiColorEdit3(const std::string& label, + py::tuple color, + int flags) { + bindImGui(); + validateTupleSize(color, 3, "color"); + float col[3] = { + color[0].cast(), + color[1].cast(), + color[2].cast() + }; + bool changed = ImGui::ColorEdit3(label.c_str(), col, flags); + return std::make_tuple(changed, py::make_tuple(col[0], col[1], col[2])); +} + +std::tuple PyWindow::imguiColorEdit4(const std::string& label, + py::tuple color, + int flags) { + bindImGui(); + validateTupleSize(color, 4, "color"); + float col[4] = { + color[0].cast(), + color[1].cast(), + 
color[2].cast(), + color[3].cast() + }; + bool changed = ImGui::ColorEdit4(label.c_str(), col, flags); + return std::make_tuple(changed, py::make_tuple(col[0], col[1], col[2], col[3])); +} + +bool PyWindow::imguiBeginMainMenuBar() { + bindImGui(); + return ImGui::BeginMainMenuBar(); +} + +void PyWindow::imguiEndMainMenuBar() { + bindImGui(); + ImGui::EndMainMenuBar(); +} + +bool PyWindow::imguiBeginMenuBar() { + bindImGui(); + return ImGui::BeginMenuBar(); +} + +void PyWindow::imguiEndMenuBar() { + bindImGui(); + ImGui::EndMenuBar(); +} + +bool PyWindow::imguiBeginMenu(const std::string& label, bool enabled) { + bindImGui(); + return ImGui::BeginMenu(label.c_str(), enabled); +} + +void PyWindow::imguiEndMenu() { + bindImGui(); + ImGui::EndMenu(); +} + +bool PyWindow::imguiMenuItem(const std::string& label, + const py::object& shortcut, + bool selected, + bool enabled) { + bindImGui(); + std::string shortcutValue; + const char* shortcutPtr = nullptr; + if (!shortcut.is_none()) { + shortcutValue = shortcut.cast(); + shortcutPtr = shortcutValue.c_str(); + } + return ImGui::MenuItem(label.c_str(), shortcutPtr, selected, enabled); +} + +void PyWindow::imguiOpenPopup(const std::string& strId, int popupFlags) { + bindImGui(); + ImGui::OpenPopup(strId.c_str(), popupFlags); +} + +bool PyWindow::imguiBeginPopup(const std::string& strId, int flags) { + bindImGui(); + return ImGui::BeginPopup(strId.c_str(), flags); +} + +std::tuple PyWindow::imguiBeginPopupModal(const std::string& name, + const py::object& open, + int flags) { + bindImGui(); + bool openValue = open.is_none() ? true : open.cast(); + bool visible = ImGui::BeginPopupModal(name.c_str(), open.is_none() ? 
nullptr : &openValue, flags); + if (open.is_none()) { + return std::make_tuple(visible, py::none()); + } + return std::make_tuple(visible, py::cast(openValue)); +} + +void PyWindow::imguiEndPopup() { + bindImGui(); + ImGui::EndPopup(); +} + +void PyWindow::imguiCloseCurrentPopup() { + bindImGui(); + ImGui::CloseCurrentPopup(); +} + +void PyWindow::imguiPushStyleColor(int idx, py::tuple color) { + bindImGui(); + ImVec4 col = tupleToVec4(color, "color"); + ImGui::PushStyleColor(idx, col); +} + +void PyWindow::imguiPopStyleColor(int count) { + bindImGui(); + ImGui::PopStyleColor(count); +} + +void PyWindow::imguiPushStyleVarFloat(int idx, float value) { + bindImGui(); + ImGui::PushStyleVar(idx, value); +} + +void PyWindow::imguiPushStyleVarVec2(int idx, py::tuple value) { + bindImGui(); + ImVec2 vec = objectToVec2(value, "value"); + ImGui::PushStyleVar(idx, vec); +} + +void PyWindow::imguiPopStyleVar(int count) { + bindImGui(); + ImGui::PopStyleVar(count); +} + +float PyWindow::imguiGetFontGlobalScale() { + bindImGui(); + return ImGui::GetIO().FontGlobalScale; +} + +void PyWindow::imguiSetFontGlobalScale(float scale) { + bindImGui(); + ImGui::GetIO().FontGlobalScale = scale; +} + +py::tuple PyWindow::imguiGetStyleColorVec4(int idx) { + bindImGui(); + ImVec4 col = ImGui::GetStyle().Colors[idx]; + return vec4ToTuple(col); +} + +void PyWindow::imguiSetStyleColorVec4(int idx, py::tuple color) { + bindImGui(); + ImVec4 col = tupleToVec4(color, "color"); + ImGui::GetStyle().Colors[idx] = col; +} + void PyWindow::ensureValid() const { if (!window_.wnd) { throw std::runtime_error("Window has been closed"); @@ -394,6 +639,27 @@ void PyWindow::validateTupleSize(const py::tuple& tpl, size_t expected, const ch } } +ImVec2 PyWindow::objectToVec2(const py::object& obj, const char* name) { + if (obj.is_none()) { + return ImVec2(0.0f, 0.0f); + } + py::tuple tpl = obj.cast(); + validateTupleSize(tpl, 2, name); + return ImVec2(tpl[0].cast(), tpl[1].cast()); +} + +ImVec4 
PyWindow::tupleToVec4(const py::tuple& tpl, const char* name) { + validateTupleSize(tpl, 4, name); + return ImVec4(tpl[0].cast(), + tpl[1].cast(), + tpl[2].cast(), + tpl[3].cast()); +} + +py::tuple PyWindow::vec4ToTuple(const ImVec4& vec) { + return py::make_tuple(vec.x, vec.y, vec.z, vec.w); +} + PyComputeProgram MakeComputeProgramFromGLSL(const std::string& source, uint32_t roCount, uint32_t rwCount) { diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h index 03d5a905..feba7229 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.h +++ b/TensorFrost/src/Definitions/VulkanInterface.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -13,6 +14,9 @@ #include "Backend/Vulkan.h" #include "Backend/Window.h" +struct ImVec2; +struct ImVec4; + namespace TensorFrost { class PyBuffer { @@ -125,12 +129,72 @@ class PyWindow { void imguiAddBackgroundText(const std::string& text, pybind11::tuple pos, pybind11::tuple color); + void imguiSameLine(float offsetFromStartX, float spacing); + void imguiSeparator(); + void imguiSpacing(); + void imguiIndent(float indentW); + void imguiUnindent(float indentW); + bool imguiBeginChild(const std::string& id, + const pybind11::object& size, + bool border, + int flags); + void imguiEndChild(); + void imguiTextWrapped(const std::string& text); + void imguiTextColored(pybind11::tuple color, + const std::string& text); + void imguiBulletText(const std::string& text); + std::tuple imguiInputText(const std::string& label, + const std::string& value, + size_t bufferLength, + int flags); + int imguiInputInt(const std::string& label, int value, int step, int stepFast, int flags); + float imguiInputFloat(const std::string& label, + float value, + float step, + float stepFast, + const std::string& format, + int flags); + std::tuple imguiColorEdit3(const std::string& label, + pybind11::tuple color, + int flags); + std::tuple imguiColorEdit4(const std::string& label, + 
pybind11::tuple color, + int flags); + bool imguiBeginMainMenuBar(); + void imguiEndMainMenuBar(); + bool imguiBeginMenuBar(); + void imguiEndMenuBar(); + bool imguiBeginMenu(const std::string& label, bool enabled); + void imguiEndMenu(); + bool imguiMenuItem(const std::string& label, + const pybind11::object& shortcut, + bool selected, + bool enabled); + void imguiOpenPopup(const std::string& strId, int popupFlags); + bool imguiBeginPopup(const std::string& strId, int flags); + std::tuple imguiBeginPopupModal(const std::string& name, + const pybind11::object& open, + int flags); + void imguiEndPopup(); + void imguiCloseCurrentPopup(); + void imguiPushStyleColor(int idx, pybind11::tuple color); + void imguiPopStyleColor(int count); + void imguiPushStyleVarFloat(int idx, float value); + void imguiPushStyleVarVec2(int idx, pybind11::tuple value); + void imguiPopStyleVar(int count); + float imguiGetFontGlobalScale(); + void imguiSetFontGlobalScale(float scale); + pybind11::tuple imguiGetStyleColorVec4(int idx); + void imguiSetStyleColorVec4(int idx, pybind11::tuple color); private: void ensureValid() const; void moveFrom(PyWindow&& other); ImGuiContext* bindImGui(); static void validateTupleSize(const pybind11::tuple& tpl, size_t expected, const char* name); + static ImVec2 objectToVec2(const pybind11::object& obj, const char* name); + static ImVec4 tupleToVec4(const pybind11::tuple& tpl, const char* name); + static pybind11::tuple vec4ToTuple(const ImVec4& vec); VulkanContext* ctx_{}; WindowContext window_{}; diff --git a/examples/imgui_showcase.py b/examples/imgui_showcase.py new file mode 100644 index 00000000..c192eb05 --- /dev/null +++ b/examples/imgui_showcase.py @@ -0,0 +1,271 @@ +"""Interactive ImGui showcase demonstrating the TensorFrost Vulkan window helpers. + +This example exercises a large portion of the Python ImGui bindings that are +exposed through ``TensorFrost.Window``. 
+""" + +from __future__ import annotations + +import math +import time +from collections import deque + +import numpy as np +import TensorFrost as tf + +# ImGui enums we need inside the sample. They mirror the values from imgui.h. +IMGUI_WINDOW_FLAGS_MENU_BAR = 1 << 3 +IMGUI_COL_WINDOW_BG = 2 +IMGUI_COL_BUTTON = 21 +IMGUI_COL_BUTTON_HOVERED = 22 +IMGUI_STYLEVAR_WINDOW_PADDING = 2 +IMGUI_STYLEVAR_FRAME_ROUNDING = 12 +IMGUI_STYLEVAR_ITEM_SPACING = 14 + + +def main() -> None: + width, height = 960, 600 + window = tf.createWindow(width, height, "TensorFrost ImGui Showcase") + + sample_history: deque[float] = deque(maxlen=512) + frame_times: deque[float] = deque(maxlen=240) + + start_time = time.perf_counter() + last_time = start_time + total_time = 0.0 + + state = { + "animate": True, + "show_plot": True, + "wave_speed": 1.0, + "wave_scale": 1.0, + "sample_count": 180, + "greeting": "Hello from TensorFrost!", + "accent": (0.25, 0.62, 0.98, 1.0), + "theme": "dark", + "font_scale": 2.0, + } + + themes = { + "dark": { + IMGUI_COL_WINDOW_BG: (0.1, 0.12, 0.16, 1.0), + IMGUI_COL_BUTTON: (0.27, 0.44, 0.85, 1.0), + IMGUI_COL_BUTTON_HOVERED: (0.36, 0.53, 0.92, 1.0), + }, + "light": { + IMGUI_COL_WINDOW_BG: (0.95, 0.96, 1.0, 1.0), + IMGUI_COL_BUTTON: (0.60, 0.78, 1.0, 1.0), + IMGUI_COL_BUTTON_HOVERED: (0.47, 0.66, 0.94, 1.0), + }, + "retro": { + IMGUI_COL_WINDOW_BG: (0.12, 0.10, 0.08, 1.0), + IMGUI_COL_BUTTON: (0.93, 0.74, 0.28, 1.0), + IMGUI_COL_BUTTON_HOVERED: (0.98, 0.84, 0.39, 1.0), + }, + } + + def apply_theme(name: str) -> None: + colors = themes[name] + for col_idx, value in colors.items(): + window.imgui_set_style_color_vec4(col_idx, value) + + apply_theme(state["theme"]) + window.imgui_scale_all_sizes(2.0) + window.imgui_set_font_global_scale(state["font_scale"]) + + try: + while window.isOpen(): + now = time.perf_counter() + dt = now - last_time + last_time = now + total_time += dt + + frame_times.append(dt) + fps = len(frame_times) / sum(frame_times) if frame_times 
else 0.0 + + if state["animate"]: + sample = math.sin(total_time * state["wave_speed"]) * state["wave_scale"] + sample_history.append(sample) + else: + # Keep the history flat when paused so the plot stays visible. + if sample_history: + sample_history.append(sample_history[-1]) + else: + sample_history.append(0.0) + + # Keep only the user-requested number of samples from the history. + history_count = max(2, min(state["sample_count"], len(sample_history))) + history_array = np.array(list(sample_history)[-history_count:], dtype=np.float32) + + width, height = window.size + + if window.imgui_begin_main_menu_bar(): + if window.imgui_begin_menu("Theme"): + for name in themes: + if window.imgui_menu_item( + name.title(), shortcut=None, selected=state["theme"] == name, enabled=True + ): + state["theme"] = name + apply_theme(name) + window.imgui_end_menu() + + if window.imgui_begin_menu("View"): + if window.imgui_menu_item( + "Toggle animation", shortcut="A", selected=state["animate"], enabled=True + ): + state["animate"] = not state["animate"] + if window.imgui_menu_item( + "Show plot", shortcut="P", selected=state["show_plot"], enabled=True + ): + state["show_plot"] = not state["show_plot"] + if window.imgui_menu_item( + "Reset font scale", shortcut="Ctrl+0", selected=False, enabled=True + ): + state["font_scale"] = 2.0 + window.imgui_set_font_global_scale(state["font_scale"]) + window.imgui_end_menu() + + if window.imgui_begin_menu("Help"): + if window.imgui_menu_item("About TensorFrost", shortcut=None, selected=False, enabled=True): + window.imgui_open_popup("about_popup") + window.imgui_end_menu() + + window.imgui_end_main_menu_bar() + + about_visible, _ = window.imgui_begin_popup_modal("about_popup", open=None, flags=0) + if about_visible: + window.imgui_text_wrapped( + "TensorFrost ImGui showcase demonstrating the Python bindings for the Vulkan backend." 
+ ) + window.imgui_spacing() + window.imgui_text("Bindings exercised:") + window.imgui_indent(12.0) + window.imgui_bullet_text("Main menu bar helpers") + window.imgui_bullet_text("Layout, widgets, and style stack APIs") + window.imgui_bullet_text("Background draw list utilities") + window.imgui_unindent(12.0) + window.imgui_spacing() + if window.imgui_button("Close##about"): + window.imgui_close_current_popup() + window.imgui_end_popup() + + visible, _ = window.imgui_begin( + "Control Center", open=None, flags=IMGUI_WINDOW_FLAGS_MENU_BAR + ) + if visible: + if window.imgui_begin_menu_bar(): + if window.imgui_begin_menu("View"): + if window.imgui_menu_item("Reset font scale", shortcut="Ctrl+0", selected=False, enabled=True): + state["font_scale"] = 2.0 + window.imgui_set_font_global_scale(state["font_scale"]) + if window.imgui_menu_item("Show plot", shortcut="P", selected=state["show_plot"], enabled=True): + state["show_plot"] = not state["show_plot"] + window.imgui_end_menu() + + window.imgui_end_menu_bar() + + window.imgui_text(f"Frame time: {dt * 1000.0:5.2f} ms") + window.imgui_text_colored((0.25, 0.85, 0.45, 1.0), f"Average FPS: {fps:5.1f}") + window.imgui_separator() + + controls_visible = window.imgui_begin_child("controls", (320, 280), border=True, flags=0) + if controls_visible: + edited, new_text = window.imgui_input_text("Greeting", state["greeting"], buffer_length=128) + if edited: + state["greeting"] = new_text + + state["animate"] = window.imgui_checkbox("Animate", state["animate"]) + window.imgui_same_line() + state["show_plot"] = window.imgui_checkbox("Show plot", state["show_plot"]) + + state["wave_speed"] = window.imgui_slider_float("Wave speed", state["wave_speed"], 0.1, 5.0) + state["wave_scale"] = window.imgui_slider_float("Wave scale", state["wave_scale"], 0.1, 3.0) + state["sample_count"] = window.imgui_slider_int("History samples", state["sample_count"], 20, sample_history.maxlen) + + window.imgui_spacing() + state["font_scale"] = 
window.imgui_slider_float("Font scale", state["font_scale"], 0.5, 2.0) + window.imgui_set_font_global_scale(state["font_scale"]) + + window.imgui_spacing() + changed, new_color = window.imgui_color_edit4("Accent color", state["accent"]) + if changed: + state["accent"] = tuple(new_color) + + window.imgui_spacing() + window.imgui_push_style_var_float(IMGUI_STYLEVAR_FRAME_ROUNDING, 8.0) + window.imgui_push_style_color(IMGUI_COL_BUTTON, state["accent"]) + window.imgui_push_style_color(IMGUI_COL_BUTTON_HOVERED, ( + min(1.0, state["accent"][0] + 0.1), + min(1.0, state["accent"][1] + 0.1), + min(1.0, state["accent"][2] + 0.1), + state["accent"][3], + )) + if window.imgui_button("Take snapshot"): + window.imgui_open_popup("snapshot_popup") + window.imgui_pop_style_color(2) + window.imgui_pop_style_var() + + visible_popup, _ = window.imgui_begin_popup_modal("snapshot_popup", open=None, flags=0) + if visible_popup: + window.imgui_text_wrapped( + "Pretend we stored the latest plot sample to disk." + ) + window.imgui_spacing() + if window.imgui_button("Close"): + window.imgui_close_current_popup() + window.imgui_end_popup() + window.imgui_end_child() + + window.imgui_same_line() + + details_visible = window.imgui_begin_child("details", (0, 280), border=True, flags=0) + if details_visible: + window.imgui_text("Details") + window.imgui_separator() + window.imgui_text_wrapped( + "The live sine wave demonstrates sliders, checkboxes, plot widgets, " + "menus, popups, style stacks, and color editing exposed through the Vulkan backend." 
+ ) + + window.imgui_spacing() + window.imgui_indent(10.0) + window.imgui_bullet_text("Toggle themes from the menu bar.") + window.imgui_bullet_text("Drag the sliders to shape the waveform.") + window.imgui_bullet_text("Use the accent color to restyle the snapshot button.") + window.imgui_unindent(10.0) + + window.imgui_spacing() + window.imgui_text_colored(state["accent"], f"Greeting: {state['greeting']}") + + window.imgui_spacing() + window.imgui_push_style_var_vec2(IMGUI_STYLEVAR_WINDOW_PADDING, (12.0, 12.0)) + window.imgui_text("Background text is rendered via the overlay draw list.") + window.imgui_pop_style_var() + window.imgui_end_child() + + window.imgui_separator() + + if state["show_plot"] and history_array.size > 1: + window.imgui_plot_lines( + "Sine wave", history_array, values_offset=0, overlay_text="Normalized", scale_min=-3.0, + scale_max=3.0, graph_size=(0.0, 140.0), stride=4 + ) + + window.imgui_spacing() + window.imgui_text(f"Window size: {width} x {height}") + + window.imgui_end() + + window.imgui_add_background_text( + state["greeting"], + pos=(24.0, 24.0), + color=(state["accent"][0], state["accent"][1], state["accent"][2], 0.12), + ) + + window.present() + finally: + window.close() + + +if __name__ == "__main__": + main() diff --git a/tests/imgui_test.py b/tests/imgui_test.py index 2c150139..371db5c2 100644 --- a/tests/imgui_test.py +++ b/tests/imgui_test.py @@ -37,11 +37,40 @@ def managed_window(width=320, height=240, title="ImGui Test Window"): class ImGuiIntegrationTest(unittest.TestCase): def test_imgui_basic_widgets(self): with managed_window() as win: + if win.imgui_begin_main_menu_bar(): + if win.imgui_begin_menu("Root"): + self.assertIn(win.imgui_menu_item("Item"), (True, False)) + win.imgui_end_menu() + win.imgui_end_main_menu_bar() + visible, open_flag = win.imgui_begin("Main Panel") self.assertTrue(visible, "ImGui window should be visible on begin") self.assertIsNone(open_flag, "Default begin call should return None for open flag") 
win.imgui_text("Hello from TensorFrost tests") + win.imgui_same_line() + win.imgui_text("Inline") + win.imgui_spacing() + win.imgui_separator() + win.imgui_indent() + child_visible = win.imgui_begin_child("child", size=(120.0, 60.0), border=False) + self.assertIsInstance(child_visible, bool) + if child_visible: + win.imgui_text_wrapped("Wrapped text inside child region to check bindings work as expected.") + win.imgui_bullet_text("Bullet item content") + win.imgui_end_child() + win.imgui_unindent() + win.imgui_text_colored((1.0, 0.0, 0.0, 1.0), "Colored text") + alpha_scale = win.imgui_get_font_global_scale() + win.imgui_set_font_global_scale(alpha_scale) + style_color = win.imgui_get_style_color_vec4(0) + self.assertEqual(len(style_color), 4) + win.imgui_push_style_color(0, style_color) + win.imgui_push_style_var_float(0, 0.95) + win.imgui_push_style_var_vec2(2, (4.0, 4.0)) + win.imgui_pop_style_var(2) + win.imgui_pop_style_color() + win.imgui_set_style_color_vec4(0, style_color) button_pressed = win.imgui_button("Press Me") self.assertIn(button_pressed, (True, False)) @@ -55,6 +84,45 @@ def test_imgui_basic_widgets(self): data = np.linspace(0.0, 1.0, 32, dtype=np.float32) win.imgui_plot_lines("Plot", data, overlay_text="test", graph_size=(120.0, 40.0)) + text_changed, text_value = win.imgui_input_text("Input", "hello") + self.assertIsInstance(text_changed, bool) + self.assertIsInstance(text_value, str) + + updated_int_input = win.imgui_input_int("Input Int", 10) + self.assertIsInstance(updated_int_input, int) + + updated_float_input = win.imgui_input_float("Input Float", 3.14) + self.assertIsInstance(updated_float_input, float) + + color_changed3, rgb = win.imgui_color_edit3("Color3", (0.1, 0.2, 0.3)) + self.assertIsInstance(color_changed3, bool) + self.assertEqual(len(rgb), 3) + + color_changed4, rgba = win.imgui_color_edit4("Color4", (0.1, 0.2, 0.3, 0.4)) + self.assertIsInstance(color_changed4, bool) + self.assertEqual(len(rgba), 4) + + if 
win.imgui_begin_menu_bar(): + if win.imgui_begin_menu("File"): + self.assertIn(win.imgui_menu_item("New"), (True, False)) + win.imgui_end_menu() + win.imgui_end_menu_bar() + + win.imgui_open_popup("ContextPopup") + if win.imgui_begin_popup("ContextPopup"): + win.imgui_text("Popup body") + win.imgui_close_current_popup() + win.imgui_end_popup() + + win.imgui_open_popup("ModalPopup") + visible_modal, modal_open = win.imgui_begin_popup_modal("ModalPopup", open=True) + self.assertIn(visible_modal, (True, False)) + self.assertIsInstance(modal_open, bool) + if visible_modal: + win.imgui_text("Modal body") + win.imgui_close_current_popup() + win.imgui_end_popup() + win.imgui_add_background_text("BG", (10.0, 10.0), (1.0, 1.0, 1.0, 1.0)) win.imgui_scale_all_sizes(1.0) win.imgui_end() From a301c73cdc25628e560cca387395d57188a6cd44 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:08:41 +0100 Subject: [PATCH 29/44] Mouse control support added to window --- TensorFrost/Backend/include/Backend/Window.h | 24 ++++++++- TensorFrost/Backend/src/Window.cpp | 52 +++++++++++++++++++ .../src/Definitions/VulkanBindings.cpp | 11 +++- .../src/Definitions/VulkanInterface.cpp | 38 +++++++++++++- TensorFrost/src/Definitions/VulkanInterface.h | 5 ++ examples/Slang/mandelbrot.py | 41 ++++++++++++++- 6 files changed, 166 insertions(+), 5 deletions(-) diff --git a/TensorFrost/Backend/include/Backend/Window.h b/TensorFrost/Backend/include/Backend/Window.h index 9dd742f7..4414cee7 100644 --- a/TensorFrost/Backend/include/Backend/Window.h +++ b/TensorFrost/Backend/include/Backend/Window.h @@ -10,6 +10,11 @@ struct WindowContext; void ReleaseImGui(WindowContext& ctx); struct ImGuiContext; +namespace TFWindowDetail { +void RegisterScrollContext(GLFWwindow* wnd, WindowContext* ctx); +void UnregisterScrollContext(GLFWwindow* wnd); +} + struct WindowContext { GLFWwindow* wnd{}; vk::Instance instance; @@ -35,6 +40,9 @@ struct 
WindowContext { std::vector framebuffers; ImGuiContext* imguiContext{}; bool imguiFrameActive = false; + double scrollDeltaX = 0.0; + double scrollDeltaY = 0.0; + GLFWscrollfun prevScrollCallback = nullptr; WindowContext() = default; WindowContext(const WindowContext&) = delete; @@ -72,6 +80,12 @@ struct WindowContext { framebuffers=std::move(o.framebuffers); imguiContext=o.imguiContext; o.imguiContext=nullptr; imguiFrameActive=o.imguiFrameActive; o.imguiFrameActive=false; + scrollDeltaX = o.scrollDeltaX; o.scrollDeltaX = 0.0; + scrollDeltaY = o.scrollDeltaY; o.scrollDeltaY = 0.0; + prevScrollCallback = o.prevScrollCallback; o.prevScrollCallback = nullptr; + if (wnd) { + TFWindowDetail::RegisterScrollContext(wnd, this); + } } void cleanup() { @@ -95,9 +109,16 @@ struct WindowContext { if (swapchain) device.destroySwapchainKHR(swapchain), swapchain=nullptr; } if (surface) instance.destroySurfaceKHR(surface), surface=nullptr; - if (wnd) { glfwDestroyWindow(wnd); wnd=nullptr; } + if (wnd) { + TFWindowDetail::UnregisterScrollContext(wnd); + glfwDestroyWindow(wnd); + wnd=nullptr; + } // leave GLFW alive; app can call glfwTerminate() once at shutdown if it wants device=nullptr; instance=nullptr; queue=nullptr; + scrollDeltaX = 0.0; + scrollDeltaY = 0.0; + prevScrollCallback = nullptr; } }; @@ -105,6 +126,7 @@ WindowContext createWindow(int width, int height, const char* title); bool windowOpen(const WindowContext& ctx); void drawBuffer(WindowContext& ctx, vk::Buffer src, uint32_t width, uint32_t height, vk::DeviceSize offset = 0); void drawBuffer(WindowContext& ctx, const Buffer& b, uint32_t w, uint32_t h, size_t offset = 0); +void AttachWindowCallbacks(WindowContext& ctx); // Global helpers used by higher-level integrations / Python bindings WindowContext* GetWindow(); diff --git a/TensorFrost/Backend/src/Window.cpp b/TensorFrost/Backend/src/Window.cpp index 88de44c1..0cf981ae 100644 --- a/TensorFrost/Backend/src/Window.cpp +++ b/TensorFrost/Backend/src/Window.cpp @@ 
-9,6 +9,10 @@ #include #include #include +#include + +static std::unordered_map gScrollContexts; +static std::mutex gScrollMutex; namespace { std::unique_ptr gWindow; @@ -19,6 +23,28 @@ void CheckVkResult(VkResult err) { throw std::runtime_error("ImGui Vulkan backend error: " + std::to_string(err)); } +void ScrollCallback(GLFWwindow* wnd, double xoffset, double yoffset) { + WindowContext* ctx = nullptr; + GLFWscrollfun prev = nullptr; + { + std::scoped_lock lock(gScrollMutex); + auto it = gScrollContexts.find(wnd); + if (it != gScrollContexts.end()) { + ctx = it->second; + prev = ctx ? ctx->prevScrollCallback : nullptr; + } + } + + if (ctx) { + ctx->scrollDeltaX += xoffset; + ctx->scrollDeltaY += yoffset; + } + + if (prev && prev != ScrollCallback) { + prev(wnd, xoffset, yoffset); + } +} + void EnsureFramebuffers(WindowContext& ctx) { if (ctx.framebuffers.size() == ctx.images.size() && !ctx.framebuffers.empty()) return; @@ -242,6 +268,31 @@ void RecreateSwapchain(WindowContext& ctx, vk::Extent2D desiredExtent) { } } // namespace +namespace TFWindowDetail { + +void RegisterScrollContext(GLFWwindow* wnd, WindowContext* ctx) { + if (!wnd) return; + std::scoped_lock lock(gScrollMutex); + gScrollContexts[wnd] = ctx; +} + +void UnregisterScrollContext(GLFWwindow* wnd) { + if (!wnd) return; + std::scoped_lock lock(gScrollMutex); + gScrollContexts.erase(wnd); +} + +} // namespace TFWindowDetail + +void AttachWindowCallbacks(WindowContext& ctx) { + if (!ctx.wnd) return; + TFWindowDetail::RegisterScrollContext(ctx.wnd, &ctx); + GLFWscrollfun prev = glfwSetScrollCallback(ctx.wnd, ScrollCallback); + if (prev && prev != ScrollCallback) { + ctx.prevScrollCallback = prev; + } +} + void ReleaseImGui(WindowContext& ctx) { if (!ctx.imguiContext) return; ImGui::SetCurrentContext(ctx.imguiContext); @@ -484,6 +535,7 @@ void ShowWindow(int width, int height, const char* title) { std::scoped_lock lock(gWindowMutex); if (gWindow && windowOpen(*gWindow)) return; gWindow = 
std::make_unique(createWindow(width, height, title)); + AttachWindowCallbacks(*gWindow); } void HideWindow() { diff --git a/TensorFrost/src/Definitions/VulkanBindings.cpp b/TensorFrost/src/Definitions/VulkanBindings.cpp index 6e1550ab..28c682a3 100644 --- a/TensorFrost/src/Definitions/VulkanBindings.cpp +++ b/TensorFrost/src/Definitions/VulkanBindings.cpp @@ -216,7 +216,16 @@ void VulkanDefinitions(py::module_& m) { "Get a style color as an RGBA tuple.") .def("imgui_set_style_color_vec4", &PyWindow::imguiSetStyleColorVec4, py::arg("index"), py::arg("color"), - "Set a style color from an RGBA tuple."); + "Set a style color from an RGBA tuple.") + .def("mouse_position", &PyWindow::mousePosition, + "Get the current mouse position in window coordinates.") + .def("is_mouse_button_pressed", &PyWindow::isMouseButtonPressed, + py::arg("button"), + "Return True if the specified mouse button is pressed.") + .def("imgui_want_capture_mouse", &PyWindow::imguiWantCaptureMouse, + "Return True if ImGui wants to capture mouse input this frame.") + .def("consume_scroll_delta", &PyWindow::consumeScrollDelta, + "Consume the accumulated scroll delta as a tuple ``(x, y)`` and reset it to zero."); m.def("createWindow", [](int width, int height, const std::string& title) { diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp index 094346df..40edb4d7 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.cpp +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -231,7 +231,9 @@ void PyComputeProgram::moveFrom(PyComputeProgram&& other) { } PyWindow::PyWindow(int width, int height, const std::string& title) - : ctx_(&getVulkanContext()), window_(createWindow(width, height, title.c_str())) {} + : ctx_(&getVulkanContext()), window_(createWindow(width, height, title.c_str())) { + AttachWindowCallbacks(window_); +} PyWindow::~PyWindow() = default; @@ -611,6 +613,37 @@ void PyWindow::imguiSetStyleColorVec4(int idx, py::tuple color) { 
ImGui::GetStyle().Colors[idx] = col; } +py::tuple PyWindow::mousePosition() { + ensureValid(); + double x = 0.0; + double y = 0.0; + glfwGetCursorPos(window_.wnd, &x, &y); + return py::make_tuple(x, y); +} + +bool PyWindow::isMouseButtonPressed(int button) { + ensureValid(); + return glfwGetMouseButton(window_.wnd, button) == GLFW_PRESS; +} + +bool PyWindow::imguiWantCaptureMouse() const { + ensureValid(); + if (!window_.imguiContext) { + return false; + } + ImGui::SetCurrentContext(window_.imguiContext); + return ImGui::GetIO().WantCaptureMouse; +} + +py::tuple PyWindow::consumeScrollDelta() { + ensureValid(); + double dx = window_.scrollDeltaX; + double dy = window_.scrollDeltaY; + window_.scrollDeltaX = 0.0; + window_.scrollDeltaY = 0.0; + return py::make_tuple(dx, dy); +} + void PyWindow::ensureValid() const { if (!window_.wnd) { throw std::runtime_error("Window has been closed"); @@ -621,6 +654,9 @@ void PyWindow::moveFrom(PyWindow&& other) { ctx_ = other.ctx_; window_ = std::move(other.window_); other.ctx_ = nullptr; + if (window_.wnd) { + TFWindowDetail::RegisterScrollContext(window_.wnd, &window_); + } } ImGuiContext* PyWindow::bindImGui() { diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h index feba7229..9d299b06 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.h +++ b/TensorFrost/src/Definitions/VulkanInterface.h @@ -187,6 +187,11 @@ class PyWindow { pybind11::tuple imguiGetStyleColorVec4(int idx); void imguiSetStyleColorVec4(int idx, pybind11::tuple color); + pybind11::tuple mousePosition(); + bool isMouseButtonPressed(int button); + bool imguiWantCaptureMouse() const; + pybind11::tuple consumeScrollDelta(); + private: void ensureValid() const; void moveFrom(PyWindow&& other); diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py index 1d5e62ee..5575ce44 100644 --- a/examples/Slang/mandelbrot.py +++ b/examples/Slang/mandelbrot.py @@ -1,5 +1,6 @@ from pathlib import Path 
import math +import time import numpy as np import TensorFrost as tf @@ -13,6 +14,8 @@ def load_shader() -> str: def main() -> None: width, height = 1024, 768 win = tf.createWindow(width, height, "Mandelbrot (ImGui)") + win.imgui_scale_all_sizes(2.0) + win.imgui_set_font_global_scale(2.0) fmt = int(win.format) is_bgra_default = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB @@ -27,6 +30,7 @@ def main() -> None: center = [-0.5, 0.0] scale = 3.0 log_scale = math.log10(scale) + pending_scroll = 0.0 manual_iterations = 500 auto_iterations = True swap_rb = is_bgra_default @@ -34,6 +38,10 @@ def main() -> None: history_index = 0 params = np.zeros(8, dtype=np.float32) + prev_mouse_pos = win.mouse_position() + dragging = False + prev_time = time.perf_counter() + fps = 0.0 def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: nonlocal pixel_buffer, pixel_capacity @@ -45,18 +53,30 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: try: while win.isOpen(): + now = time.perf_counter() + dt = max(now - prev_time, 1e-6) + prev_time = now + fps = fps * 0.9 + (1.0 / dt) * 0.1 + width, height = win.size width = max(1, int(width)) height = max(1, int(height)) ensure_pixel_buffer(width, height) + if pending_scroll: + scroll_adjust = pending_scroll * 0.12 + log_scale = max(min(log_scale - scroll_adjust, 0.5), -4.5) + scale = pow(10.0, log_scale) + pending_scroll = 0.0 + aspect = height / float(width) visible, _ = win.imgui_begin("Mandelbrot Controls", open=True) if visible: win.imgui_text(f"Resolution: {width} × {height}") win.imgui_text(f"Format: {fmt}") + win.imgui_text(f"FPS: {fps:5.1f} | {dt * 1000.0:.2f} ms") log_scale = win.imgui_slider_float("log₁₀ scale", log_scale, -4.5, 0.5) scale = pow(10.0, log_scale) center[0] = win.imgui_slider_float("Center X", center[0], -2.5, 1.5) @@ -75,6 +95,7 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: log_scale = math.log10(scale) manual_iterations = 500 auto_iterations = True + 
pending_scroll = 0.0 plot_history[history_index % plot_history.size] = float(manual_iterations) history_index += 1 @@ -88,11 +109,23 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: xspan = scale yspan = xspan * aspect - xmin = center[0] - xspan * 0.5 - ymin = center[1] - yspan * 0.5 dx = xspan / width dy = yspan / height + mouse_pos = win.mouse_position() + want_capture_mouse = win.imgui_want_capture_mouse() + mouse_down = win.is_mouse_button_pressed(0) + dragging_now = mouse_down and not want_capture_mouse + if dragging and dragging_now: + delta_x = mouse_pos[0] - prev_mouse_pos[0] + delta_y = mouse_pos[1] - prev_mouse_pos[1] + center[0] -= delta_x * dx + center[1] -= delta_y * dy + dragging = dragging_now + prev_mouse_pos = mouse_pos + + xmin = center[0] - xspan * 0.5 + ymin = center[1] - yspan * 0.5 params[0] = float(width) params[1] = float(height) params[2] = xmin @@ -106,6 +139,10 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: program.run([params_buffer], [pixel_buffer], width * height) win.drawBuffer(pixel_buffer, width, height) + + _, scroll_dy = win.consume_scroll_delta() + if not want_capture_mouse: + pending_scroll += scroll_dy finally: win.close() pixel_buffer.release() From e509dfadcddc461fe1668019627d5b675f29574e Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Fri, 7 Nov 2025 01:54:26 +0100 Subject: [PATCH 30/44] radix sort example --- Python/TensorFrost/__init__.py | 1 - Python/TensorFrost/sort.py | 396 ++++++++++++++++-- examples/radix_sort/__main__.py | 74 ++++ examples/radix_sort/shaders/__init__.py | 0 examples/radix_sort/shaders/bucket_scan.slang | 43 ++ examples/radix_sort/shaders/histogram.slang | 59 +++ .../radix_sort/shaders/map_from_uint.slang | 33 ++ examples/radix_sort/shaders/map_to_uint.slang | 33 ++ .../radix_sort/shaders/prefix_accum.slang | 41 ++ .../radix_sort/shaders/prefix_block.slang | 31 ++ .../radix_sort/shaders/prefix_local.slang | 
47 +++ examples/radix_sort/shaders/scatter.slang | 104 +++++ examples/radix_sort/shaders/unpack.slang | 34 ++ examples/radix_sort/sort.py | 359 ++++++++++++++++ 14 files changed, 1217 insertions(+), 38 deletions(-) create mode 100644 examples/radix_sort/__main__.py create mode 100644 examples/radix_sort/shaders/__init__.py create mode 100644 examples/radix_sort/shaders/bucket_scan.slang create mode 100644 examples/radix_sort/shaders/histogram.slang create mode 100644 examples/radix_sort/shaders/map_from_uint.slang create mode 100644 examples/radix_sort/shaders/map_to_uint.slang create mode 100644 examples/radix_sort/shaders/prefix_accum.slang create mode 100644 examples/radix_sort/shaders/prefix_block.slang create mode 100644 examples/radix_sort/shaders/prefix_local.slang create mode 100644 examples/radix_sort/shaders/scatter.slang create mode 100644 examples/radix_sort/shaders/unpack.slang create mode 100644 examples/radix_sort/sort.py diff --git a/Python/TensorFrost/__init__.py b/Python/TensorFrost/__init__.py index 3fddb9ae..ce11ec6e 100644 --- a/Python/TensorFrost/__init__.py +++ b/Python/TensorFrost/__init__.py @@ -4,7 +4,6 @@ from . import regularizers from . import clipping from . import random -from . import sort from .default import * # def compile(func): diff --git a/Python/TensorFrost/sort.py b/Python/TensorFrost/sort.py index d3bfc143..277d130e 100644 --- a/Python/TensorFrost/sort.py +++ b/Python/TensorFrost/sort.py @@ -1,40 +1,362 @@ -# from . 
import TensorFrost as tf -# -# #in-place bitonic sort -# def bitonic(keys, values = None): -# tf.region_begin('Bitonic sort') -# keys = tf.copy(keys) -# if values is not None: -# values = tf.copy(values) -# element_count = keys.shape[0] -# log2_count = tf.int(tf.ceil(tf.log2(tf.float(element_count)))) -# count_round = 1 << log2_count -# idx = tf.indices([count_round / 2])[0] -# with tf.loop(log2_count) as k: -# with tf.loop(k+1) as j: -# s = 1 << (k-j) -# m_inner = s - 1 -# m_outer = ~m_inner -# m_xor = s + tf.select(j == 0, m_inner, 0) -# -# id1 = (2 * (idx & m_outer) + (idx & m_inner)) -# id2 = id1 ^ m_xor -# key1, key2 = keys[id1], keys[id2] -# with tf.if_cond((key1 >= key2) & (id1 < element_count) & (id2 < element_count)): -# if values is not None: -# val1, val2 = values[id1], values[id2] -# values[id1] = val2 -# values[id2] = val1 -# keys[id1] = key2 -# keys[id2] = key1 -# -# tf.region_end('Bitonic sort') -# if values is not None: -# return keys, values -# else: -# return keys -# -# #histogram radix sort +from __future__ import annotations + +from contextlib import ExitStack +from dataclasses import dataclass +from importlib import resources +from typing import Dict, Optional, Tuple + +import numpy as np + +from . 
import TensorFrost as tf + +__all__ = ["HistogramRadixSort", "radix_sort"] + +_TYPE_CODES: Dict[str, np.uint32] = { + "uint": np.uint32(0), + "int": np.uint32(1), + "float": np.uint32(2), +} + + +def _prepare_keys(keys: np.ndarray) -> Tuple[np.ndarray, np.dtype, str]: + array = np.asarray(keys) + if array.ndim != 1: + raise ValueError("radix_sort expects a 1D array of keys") + + dtype = array.dtype + if dtype == np.uint32: + return array, dtype, "uint" + + if dtype == np.int32: + return array, dtype, "int" + + if dtype == np.float32: + return array, dtype, "float" + + raise TypeError(f"Unsupported key dtype {dtype}; expected uint32, int32, or float32") + + +def _prepare_values(values: np.ndarray) -> Tuple[np.ndarray, np.dtype]: + array = np.asarray(values) + if array.ndim != 1: + raise ValueError("radix_sort expects a 1D array of values when provided") + + dtype = array.dtype + if dtype not in (np.uint32, np.int32, np.float32): + raise TypeError(f"Unsupported value dtype {dtype}; expected uint32, int32, or float32") + + return array, dtype + + +def _load_shader_source(filename: str) -> str: + package = f"{__package__}.shaders.radix" + try: + return resources.files(package).joinpath(filename).read_text(encoding="utf-8") # type: ignore[attr-defined] + except AttributeError: + return resources.read_text(package, filename) + + +@dataclass(frozen=True) +class _SorterKey: + bits_per_pass: int + block_size: int + group_size: int + + +class HistogramRadixSort: + """GPU histogram radix sort implemented with Slang + Vulkan.""" + + def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: int = 128) -> None: + if bits_per_pass <= 0: + raise ValueError("bits_per_pass must be positive") + if bits_per_pass > 8: + raise ValueError("bits_per_pass must be <= 8 to fit within MAX_HIST_SIZE") + if group_size != 128: + raise ValueError("This implementation currently requires group_size == 128") + if block_size <= 0 or block_size > 1024: + raise 
ValueError("block_size must be within (0, 1024]") + + self.bits_per_pass = bits_per_pass + self.block_size = block_size + self.group_size = group_size + + self._map_to_uint_program = tf.createComputeProgramFromSlang( + "radix_map_to_uint", + _load_shader_source("map_to_uint.slang"), + "csMapToUint", + ro_count=2, + rw_count=1, + ) + self._map_from_uint_program = tf.createComputeProgramFromSlang( + "radix_map_from_uint", + _load_shader_source("map_from_uint.slang"), + "csMapFromUint", + ro_count=2, + rw_count=1, + ) + + self._histogram_program = tf.createComputeProgramFromSlang( + "radix_histogram", + _load_shader_source("histogram.slang"), + "csHistogram", + ro_count=2, + rw_count=1, + ) + self._unpack_program = tf.createComputeProgramFromSlang( + "radix_unpack", + _load_shader_source("unpack.slang"), + "csUnpack", + ro_count=2, + rw_count=1, + ) + self._prefix_local_program = tf.createComputeProgramFromSlang( + "radix_prefix_local", + _load_shader_source("prefix_local.slang"), + "csPrefixLocal", + ro_count=2, + rw_count=2, + ) + self._prefix_blocks_program = tf.createComputeProgramFromSlang( + "radix_prefix_blocks", + _load_shader_source("prefix_block.slang"), + "csPrefixBlocks", + ro_count=2, + rw_count=1, + ) + self._prefix_accum_program = tf.createComputeProgramFromSlang( + "radix_prefix_accum", + _load_shader_source("prefix_accum.slang"), + "csPrefixAccumulate", + ro_count=2, + rw_count=1, + ) + self._bucket_scan_program = tf.createComputeProgramFromSlang( + "radix_bucket_scan", + _load_shader_source("bucket_scan.slang"), + "csBucketScan", + ro_count=2, + rw_count=1, + ) + self._scatter_program = tf.createComputeProgramFromSlang( + "radix_scatter", + _load_shader_source("scatter.slang"), + "csScatter", + ro_count=5, + rw_count=2, + ) + + self._dummy_values_buffer = tf.createBuffer(1, 4, False) + + def close(self) -> None: + for program in ( + self._map_to_uint_program, + self._map_from_uint_program, + self._histogram_program, + self._unpack_program, + 
self._prefix_local_program, + self._prefix_blocks_program, + self._prefix_accum_program, + self._bucket_scan_program, + self._scatter_program, + ): + if program is not None: + program.release() + if self._dummy_values_buffer is not None: + self._dummy_values_buffer.release() + + def sort( + self, + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + max_bits: int = 32, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + keys_array, key_dtype, key_kind = _prepare_keys(keys) + element_count = int(keys_array.shape[0]) + + if values is not None: + values_array, values_dtype = _prepare_values(values) + if values_array.shape[0] != element_count: + raise ValueError("values must have the same length as keys") + else: + values_array = None + values_dtype = None + + if element_count == 0: + empty_keys = keys_array.copy() + if values_array is None: + return empty_keys, None + return empty_keys, values_array.copy() + + max_bits = int(min(max_bits, 32)) + histogram_size = 1 << self.bits_per_pass + mask = np.uint32(histogram_size - 1) + + num_groups = max((element_count + self.group_size - 1) // self.group_size, 1) + block_count = max((num_groups + self.block_size - 1) // self.block_size, 1) + packed_count = (histogram_size + 3) // 4 + passes = max((max_bits + self.bits_per_pass - 1) // self.bits_per_pass, 1) + + params_array = np.zeros(8, dtype=np.uint32) + params_array[0] = np.uint32(element_count) + params_array[1] = np.uint32(histogram_size) + params_array[3] = mask + params_array[4] = np.uint32(num_groups) + params_array[5] = np.uint32(self.block_size) + params_array[6] = np.uint32(block_count) + params_array[7] = np.uint32(1 if values_array is not None else 0) + + map_params = np.zeros(4, dtype=np.uint32) + map_params[0] = np.uint32(element_count) + map_params[1] = _TYPE_CODES[key_kind] + + with ExitStack() as stack: + params_buffer = tf.createBuffer(params_array.size, 4, True) + stack.callback(params_buffer.release) + params_buffer.setData(params_array) + + 
map_buffer = tf.createBuffer(map_params.size, 4, True) + stack.callback(map_buffer.release) + map_buffer.setData(map_params) + + key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + for buf in key_buffers: + stack.callback(buf.release) + key_buffers[0].setData(keys_array) + + if values_array is not None: + value_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + for buf in value_buffers: + stack.callback(buf.release) + value_buffers[0].setData(values_array) + else: + dummy = self._dummy_values_buffer + value_buffers = [dummy, dummy] + + packed_hist_buffer = tf.createBuffer(max(packed_count * num_groups, 1), 4, False) + group_hist_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + prefix_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + block_totals_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + block_prefix_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + bucket_scan_buffer = tf.createBuffer(max(histogram_size, 1), 4, False) + + for buf in ( + packed_hist_buffer, + group_hist_buffer, + prefix_buffer, + block_totals_buffer, + block_prefix_buffer, + bucket_scan_buffer, + ): + stack.callback(buf.release) + + self._map_to_uint_program.run( + [map_buffer, key_buffers[0]], + [key_buffers[1]], + element_count, + ) + + key_in = key_buffers[1] + key_out = key_buffers[0] + val_in, val_out = value_buffers + + for pass_index in range(passes): + params_array[2] = np.uint32(pass_index * self.bits_per_pass) + params_buffer.setData(params_array) + + dispatch_threads = num_groups * self.group_size + self._histogram_program.run( + [params_buffer, key_in], + [packed_hist_buffer], + dispatch_threads, + ) + + self._unpack_program.run( + [params_buffer, packed_hist_buffer], + [group_hist_buffer], + histogram_size * num_groups, + ) + + self._prefix_local_program.run( + [params_buffer, group_hist_buffer], + [prefix_buffer, 
block_totals_buffer], + histogram_size * block_count, + ) + + self._prefix_blocks_program.run( + [params_buffer, block_totals_buffer], + [block_prefix_buffer], + histogram_size, + ) + + self._prefix_accum_program.run( + [params_buffer, block_prefix_buffer], + [prefix_buffer], + histogram_size * block_count, + ) + + self._bucket_scan_program.run( + [params_buffer, prefix_buffer], + [bucket_scan_buffer], + histogram_size, + ) + + self._scatter_program.run( + [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], + [key_out, val_out], + dispatch_threads, + ) + + key_in, key_out = key_out, key_in + if values_array is not None: + val_in, val_out = val_out, val_in + + self._map_from_uint_program.run( + [map_buffer, key_in], + [key_out], + element_count, + ) + + sorted_keys = key_out.getData(key_dtype, element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_in.getData(values_dtype, element_count) + else: + sorted_values = None + + return sorted_keys, sorted_values + + +_SORTER_CACHE: Dict[_SorterKey, HistogramRadixSort] = {} + + +def _get_sorter(bits_per_pass: int, block_size: int, group_size: int) -> HistogramRadixSort: + key = _SorterKey(bits_per_pass, block_size, group_size) + sorter = _SORTER_CACHE.get(key) + if sorter is None: + sorter = HistogramRadixSort(bits_per_pass=bits_per_pass, block_size=block_size, group_size=group_size) + _SORTER_CACHE[key] = sorter + return sorter + + +def radix_sort( + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + bits_per_pass: int = 6, + max_bits: int = 32, + block_size: int = 64, + group_size: int = 128, +): + """Run the GPU histogram radix sort on the provided keys (and optional values). + + Returns the sorted keys, and when ``values`` is provided also returns the permuted values. 
+ """ + + sorter = _get_sorter(bits_per_pass, block_size, group_size) + sorted_keys, sorted_values = sorter.sort(keys, values, max_bits=max_bits) + if values is None: + return sorted_keys + return sorted_keys, sorted_values # def radix(keys, values = None, bits_per_pass = 6, max_bits = 32): # def prefix_sum_grouped(A, axis = -1): # axis = len(A.shape) + axis if axis < 0 else axis diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py new file mode 100644 index 00000000..39ea0288 --- /dev/null +++ b/examples/radix_sort/__main__.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import argparse +import time +from contextlib import ExitStack + +import numpy as np +import TensorFrost as tf + +try: + from .sort import HistogramRadixSort +except ImportError: + import sys + from pathlib import Path + + _CURRENT_DIR = Path(__file__).resolve().parent + if str(_CURRENT_DIR) not in sys.path: + sys.path.insert(0, str(_CURRENT_DIR)) + + from sort import HistogramRadixSort + + +def _select_backend() -> None: + if hasattr(tf, "initialize"): + backend = getattr(tf, "vulkan", None) + if backend is None: + raise RuntimeError("TensorFrost Vulkan backend is unavailable on this build") + tf.initialize(backend) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Histogram radix sort demo running on the Vulkan backend.") + parser.add_argument("--size", type=int, default=1 << 20, help="Number of key/value pairs to sort") + parser.add_argument("--bits", type=int, default=6, help="Bits processed per pass") + args = parser.parse_args() + + _select_backend() + + count = max(0, int(args.size)) + bits_per_pass = max(1, int(args.bits)) + rng = np.random.default_rng(1234) + + keys = rng.standard_normal(count, dtype=np.float32) + values = rng.integers(0, 1 << 31, size=count, dtype=np.uint32) + + with ExitStack() as stack: + sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) + stack.callback(sorter.close) + start_time = time.perf_counter() 
+ sorted_keys, sorted_values = sorter.sort(keys, values) + elapsed = time.perf_counter() - start_time + + if sorted_values is None: + sorted_values = np.empty_like(values) + + order = np.argsort(keys, kind="stable") + reference_keys = keys[order] + reference_values = values[order] + + key_match = np.allclose(sorted_keys, reference_keys, atol=0.0, rtol=0.0) + value_match = np.array_equal(sorted_values, reference_values) + + print(f"Sorted {count} elements with bits_per_pass={bits_per_pass}") + print(f"Sort elapsed: {elapsed * 1e3:.3f} ms ({elapsed:.6f} s)") + print(f"Keys match reference: {key_match}") + print(f"Values match reference: {value_match}") + if count: + preview = min(10, count) + print("First few sorted keys:") + print(sorted_keys[:preview]) + + +if __name__ == "__main__": + main() diff --git a/examples/radix_sort/shaders/__init__.py b/examples/radix_sort/shaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/radix_sort/shaders/bucket_scan.slang b/examples/radix_sort/shaders/bucket_scan.slang new file mode 100644 index 00000000..e748d887 --- /dev/null +++ b/examples/radix_sort/shaders/bucket_scan.slang @@ -0,0 +1,43 @@ +static const uint GROUP_SIZE = 128; +static const uint QUARTER_SIZE = GROUP_SIZE / 4; +static const uint MAX_HIST_SIZE = 256; + +uint packedCount(uint histogramSize) +{ + return (histogramSize + 3u) >> 2; +} + +[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer GroupPrefix : register(t1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer BucketScan : register(u2, space0); + +[numthreads(64, 1, 1)] +void csBucketScan(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + if (dispatchThreadID.x != 0) + return; + + uint histogramSize = Params[1]; + uint numGroups = Params[4]; + + if (histogramSize == 0) + return; + + if (numGroups == 0) + { + for (uint bit = 0; bit < histogramSize; ++bit) + { + BucketScan[bit] = 0; + } + return; + } + + uint base = 
(numGroups - 1u) * histogramSize; + uint sum = 0; + for (uint bit = 0; bit < histogramSize; ++bit) + { + uint total = GroupPrefix[base + bit]; + sum += total; + BucketScan[bit] = sum; + } +} diff --git a/examples/radix_sort/shaders/histogram.slang b/examples/radix_sort/shaders/histogram.slang new file mode 100644 index 00000000..5a7592ed --- /dev/null +++ b/examples/radix_sort/shaders/histogram.slang @@ -0,0 +1,59 @@ +static const uint GROUP_SIZE = 128; +static const uint QUARTER_SIZE = GROUP_SIZE / 4; +static const uint MAX_HIST_SIZE = 256; + +uint packedCount(uint histogramSize) +{ + return (histogramSize + 3u) >> 2; +} + +[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer KeysIn : register(t1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer PackedHistogram : register(u2, space0); + +groupshared uint sPacked[MAX_HIST_SIZE / 4]; + +[numthreads(GROUP_SIZE, 1, 1)] +void csHistogram(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) +{ + uint elementCount = Params[0]; + uint histogramSize = Params[1]; + uint shift = Params[2]; + uint mask = Params[3]; + uint numGroups = Params[4]; + + if (histogramSize > MAX_HIST_SIZE) + return; + + uint group = groupID.x; + if (group >= numGroups) + return; + + uint packedCountLocal = packedCount(histogramSize); + uint lane = localID.x; + + for (uint idx = lane; idx < packedCountLocal; idx += GROUP_SIZE) + { + sPacked[idx] = 0; + } + GroupMemoryBarrierWithGroupSync(); + + uint globalIndex = group * GROUP_SIZE + lane; + if (globalIndex < elementCount) + { + uint key = KeysIn[globalIndex]; + uint bit = (key >> shift) & mask; + uint slot = bit >> 2; + if (slot < packedCountLocal) + { + uint laneShift = (bit & 3u) * 8u; + InterlockedAdd(sPacked[slot], 1u << laneShift); + } + } + GroupMemoryBarrierWithGroupSync(); + + for (uint idx = lane; idx < packedCountLocal; idx += GROUP_SIZE) + { + PackedHistogram[group * packedCountLocal + idx] = sPacked[idx]; + } +} diff --git 
a/examples/radix_sort/shaders/map_from_uint.slang b/examples/radix_sort/shaders/map_from_uint.slang
new file mode 100644
index 00000000..86f819f9
--- /dev/null
+++ b/examples/radix_sort/shaders/map_from_uint.slang
@@ -0,0 +1,40 @@
+// Undoes the key remapping applied by map_to_uint.slang: turns the remapped
+// uint keys back into the original int32/float32 bit patterns.
+static const uint TYPE_UINT = 0u; // no branch below: uint keys pass through unchanged
+static const uint TYPE_INT = 1u;
+static const uint TYPE_FLOAT = 2u;
+static const uint SIGN_BIT = 0x80000000u;
+static const uint FULL_MASK = 0xFFFFFFFFu;
+
+// Params[0] = element count, Params[1] = key type code (TYPE_* above).
+[[vk::binding(0,0)]] StructuredBuffer<uint> Params : register(t0, space0);
+[[vk::binding(1,0)]] StructuredBuffer<uint> Input : register(t1, space0);
+[[vk::binding(2,0)]] RWStructuredBuffer<uint> Output : register(u2, space0);
+
+// One thread per element; threads at or past Params[0] do nothing.
+[numthreads(128, 1, 1)]
+void csMapFromUint(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+    uint index = dispatchThreadID.x;
+    uint count = Params[0];
+    if (index >= count)
+        return;
+
+    uint typeCode = Params[1];
+    uint value = Input[index];
+
+    if (typeCode == TYPE_INT)
+    {
+        // Flipping the sign bit is its own inverse.
+        value ^= SIGN_BIT;
+    }
+    else if (typeCode == TYPE_FLOAT)
+    {
+        // Top bit 0: the forward map inverted every bit (source sign bit was set).
+        // Top bit 1: the forward map only flipped the sign bit.
+        uint mask = ((value >> 31) == 0u) ? FULL_MASK : SIGN_BIT;
+        value ^= mask;
+    }
+
+    Output[index] = value;
+}
diff --git a/examples/radix_sort/shaders/map_to_uint.slang b/examples/radix_sort/shaders/map_to_uint.slang
new file mode 100644
index 00000000..b6be57f5
--- /dev/null
+++ b/examples/radix_sort/shaders/map_to_uint.slang
@@ -0,0 +1,40 @@
+// Remaps int32/float32 key bit patterns to uints so that plain unsigned
+// comparison of the results follows the original numeric order.
+static const uint TYPE_UINT = 0u; // no branch below: uint keys pass through unchanged
+static const uint TYPE_INT = 1u;
+static const uint TYPE_FLOAT = 2u;
+static const uint SIGN_BIT = 0x80000000u;
+static const uint FULL_MASK = 0xFFFFFFFFu;
+
+// Params[0] = element count, Params[1] = key type code (TYPE_* above).
+[[vk::binding(0,0)]] StructuredBuffer<uint> Params : register(t0, space0);
+[[vk::binding(1,0)]] StructuredBuffer<uint> Input : register(t1, space0);
+[[vk::binding(2,0)]] RWStructuredBuffer<uint> Output : register(u2, space0);
+
+// One thread per element; threads at or past Params[0] do nothing.
+[numthreads(128, 1, 1)]
+void csMapToUint(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+    uint index = dispatchThreadID.x;
+    uint count = Params[0];
+    if (index >= count)
+        return;
+
+    uint typeCode = Params[1];
+    uint value = Input[index];
+
+    if (typeCode == TYPE_INT)
+    {
+        // Flipping the sign bit places negative ints below non-negative ones.
+        value ^= SIGN_BIT;
+    }
+    else if (typeCode == TYPE_FLOAT)
+    {
+        // Sign bit set: invert every bit. Sign bit clear: flip only the sign bit.
+        // Unsigned order of the results then follows float order (NaNs aside).
+        uint mask = ((value >> 31) == 1u) ? FULL_MASK : SIGN_BIT;
+        value ^= mask;
+    }
+
+    Output[index] = value;
+}
diff --git a/examples/radix_sort/shaders/prefix_accum.slang b/examples/radix_sort/shaders/prefix_accum.slang
new file mode 100644
index 00000000..81966007
--- /dev/null
+++ b/examples/radix_sort/shaders/prefix_accum.slang
@@ -0,0 +1,47 @@
+// For every (block, bucket) pair, adds BlockPrefix[block * histogramSize + bucket]
+// to the GroupPrefix entries of all groups that belong to that block.
+static const uint GROUP_SIZE = 128;
+static const uint QUARTER_SIZE = GROUP_SIZE / 4;
+static const uint MAX_HIST_SIZE = 256;
+
+uint packedCount(uint histogramSize)
+{
+    return (histogramSize + 3u) >> 2;
+}
+
+// Params[1] = histogram size, Params[4] = group count,
+// Params[5] = groups per block, Params[6] = block count.
+[[vk::binding(0,0)]] StructuredBuffer<uint> Params : register(t0, space0);
+[[vk::binding(1,0)]] StructuredBuffer<uint> BlockPrefix : register(t1, space0);
+[[vk::binding(2,0)]] RWStructuredBuffer<uint> GroupPrefix : register(u2, space0);
+
+// One thread per (block, bucket) pair; threads past blockCount * histogramSize exit.
+[numthreads(64, 1, 1)]
+void csPrefixAccumulate(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+    uint histogramSize = Params[1];
+    uint numGroups = Params[4];
+    uint blockSize = Params[5];
+    uint blockCount = Params[6];
+
+    uint totalThreads = blockCount * histogramSize;
+    uint index = dispatchThreadID.x;
+    if (index >= totalThreads)
+        return;
+
+    uint block = index / histogramSize;
+    uint bit = index % histogramSize;
+
+    uint startGroup = block * blockSize;
+    if (startGroup >= numGroups)
+        return;
+
+    // The last block may hold fewer than blockSize groups, so clamp the end.
+    uint endGroup = min(startGroup + blockSize, numGroups);
+    uint prefix = BlockPrefix[index];
+    for (uint g = startGroup; g < endGroup; ++g)
+    {
+        uint idx = g * histogramSize + bit;
+        GroupPrefix[idx] += prefix;
+    }
+}
diff --git a/examples/radix_sort/shaders/prefix_block.slang b/examples/radix_sort/shaders/prefix_block.slang
new file mode 100644
index 00000000..3a4909b0
--- /dev/null
+++ b/examples/radix_sort/shaders/prefix_block.slang
@@ -0,0 +1,31 @@
+static const uint GROUP_SIZE = 128;
+static const uint QUARTER_SIZE = GROUP_SIZE / 4;
+static const uint MAX_HIST_SIZE = 256;
+
+uint packedCount(uint histogramSize)
+{
+    return (histogramSize + 3u) >> 2;
+}
+
+[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0);
+[[vk::binding(1,0)]] StructuredBuffer BlockTotals : register(t1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer BlockPrefix : register(u2, space0); + +[numthreads(64, 1, 1)] +void csPrefixBlocks(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint histogramSize = Params[1]; + uint blockCount = Params[6]; + uint bucket = dispatchThreadID.x; + if (bucket >= histogramSize) + return; + + uint sum = 0; + for (uint block = 0; block < blockCount; ++block) + { + uint idx = block * histogramSize + bucket; + uint total = BlockTotals[idx]; + BlockPrefix[idx] = sum; + sum += total; + } +} diff --git a/examples/radix_sort/shaders/prefix_local.slang b/examples/radix_sort/shaders/prefix_local.slang new file mode 100644 index 00000000..5e70d6dc --- /dev/null +++ b/examples/radix_sort/shaders/prefix_local.slang @@ -0,0 +1,47 @@ +static const uint GROUP_SIZE = 128; +static const uint QUARTER_SIZE = GROUP_SIZE / 4; +static const uint MAX_HIST_SIZE = 256; + +uint packedCount(uint histogramSize) +{ + return (histogramSize + 3u) >> 2; +} + +[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer GroupHistogram : register(t1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer GroupPrefix : register(u2, space0); +[[vk::binding(3,0)]] RWStructuredBuffer BlockTotals : register(u3, space0); + +[numthreads(64, 1, 1)] +void csPrefixLocal(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint histogramSize = Params[1]; + uint numGroups = Params[4]; + uint blockSize = Params[5]; + uint blockCount = Params[6]; + + uint totalThreads = blockCount * histogramSize; + uint index = dispatchThreadID.x; + if (index >= totalThreads) + return; + + uint block = index / histogramSize; + uint bit = index % histogramSize; + + uint startGroup = block * blockSize; + if (startGroup >= numGroups) + { + BlockTotals[index] = 0; + return; + } + + uint endGroup = min(startGroup + blockSize, numGroups); + uint sum = 0; + for (uint g = startGroup; g < endGroup; 
++g) + { + uint idx = g * histogramSize + bit; + sum += GroupHistogram[idx]; + GroupPrefix[idx] = sum; + } + BlockTotals[index] = sum; +} diff --git a/examples/radix_sort/shaders/scatter.slang b/examples/radix_sort/shaders/scatter.slang new file mode 100644 index 00000000..3ff2afe3 --- /dev/null +++ b/examples/radix_sort/shaders/scatter.slang @@ -0,0 +1,104 @@ +static const uint GROUP_SIZE = 128; +static const uint QUARTER_SIZE = GROUP_SIZE / 4; +static const uint MAX_HIST_SIZE = 256; + +uint packedCount(uint histogramSize) +{ + return (histogramSize + 3u) >> 2; +} + +[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer KeysIn : register(t1, space0); +[[vk::binding(2,0)]] StructuredBuffer ValuesIn : register(t2, space0); +[[vk::binding(3,0)]] StructuredBuffer GroupPrefix : register(t3, space0); +[[vk::binding(4,0)]] StructuredBuffer BucketScan : register(t4, space0); +[[vk::binding(5,0)]] RWStructuredBuffer KeysOut : register(u5, space0); +[[vk::binding(6,0)]] RWStructuredBuffer ValuesOut : register(u6, space0); + +groupshared uint tempBits[GROUP_SIZE]; +groupshared uint halfCount[MAX_HIST_SIZE]; + +[numthreads(GROUP_SIZE, 1, 1)] +void csScatter(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) +{ + uint elementCount = Params[0]; + uint histogramSize = Params[1]; + uint shift = Params[2]; + uint mask = Params[3]; + uint numGroups = Params[4]; + uint hasValues = Params[7]; + + if (histogramSize > MAX_HIST_SIZE) + return; + + uint group = groupID.x; + if (group >= numGroups) + return; + + uint lane = localID.x; + + for (uint idx = lane; idx < histogramSize; idx += GROUP_SIZE) + { + halfCount[idx] = 0; + } + GroupMemoryBarrierWithGroupSync(); + + uint globalIndex = group * GROUP_SIZE + lane; + bool active = (globalIndex < elementCount); + uint bit = 0; + uint key = 0; + uint value = 0; + + if (active) + { + key = KeysIn[globalIndex]; + bit = (key >> shift) & mask; + tempBits[lane] = bit; + if 
(hasValues != 0) + value = ValuesIn[globalIndex]; + } + else + { + tempBits[lane] = 0; + } + GroupMemoryBarrierWithGroupSync(); + + uint quarterIndex = lane / QUARTER_SIZE; + if (active && quarterIndex < 3) + { + uint inc = 0; + if (quarterIndex < 1) inc |= 1u; + if (quarterIndex < 2) inc |= (1u << 8); + if (quarterIndex < 3) inc |= (1u << 16); + InterlockedAdd(halfCount[bit], inc); + } + GroupMemoryBarrierWithGroupSync(); + + if (!active) + return; + + uint prevBucket = (bit == 0) ? 0 : BucketScan[bit - 1u]; + uint prevGroup = (group == 0) ? 0 : GroupPrefix[(group - 1u) * histogramSize + bit]; + + uint quarterOffset = 0; + if (quarterIndex > 0) + { + uint packed = halfCount[bit]; + uint shiftBytes = (quarterIndex - 1u) * 8u; + quarterOffset = (packed >> shiftBytes) & 0xFFu; + } + + uint begin = quarterIndex * QUARTER_SIZE; + uint localCount = 0; + for (uint t = begin; t < lane; ++t) + { + if (tempBits[t] == bit) + ++localCount; + } + + uint totalOffset = prevBucket + prevGroup + quarterOffset + localCount; + KeysOut[totalOffset] = key; + + if (hasValues != 0) + ValuesOut[totalOffset] = value; +} diff --git a/examples/radix_sort/shaders/unpack.slang b/examples/radix_sort/shaders/unpack.slang new file mode 100644 index 00000000..ee53179b --- /dev/null +++ b/examples/radix_sort/shaders/unpack.slang @@ -0,0 +1,34 @@ +static const uint GROUP_SIZE = 128; +static const uint QUARTER_SIZE = GROUP_SIZE / 4; +static const uint MAX_HIST_SIZE = 256; + +uint packedCount(uint histogramSize) +{ + return (histogramSize + 3u) >> 2; +} + +[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer PackedHistogram : register(t1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer GroupHistogram : register(u2, space0); + +[numthreads(64, 1, 1)] +void csUnpack(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint histogramSize = Params[1]; + uint numGroups = Params[4]; + uint packedCountLocal = packedCount(histogramSize); + + uint total = 
numGroups * histogramSize; + uint index = dispatchThreadID.x; + if (index >= total) + return; + + uint group = index / histogramSize; + uint bit = index % histogramSize; + + uint packedIndex = group * packedCountLocal + (bit >> 2); + uint packed = PackedHistogram[packedIndex]; + uint shift = (bit & 3u) * 8u; + uint count = (packed >> shift) & 0xFFu; + GroupHistogram[index] = count; +} diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py new file mode 100644 index 00000000..68569fb7 --- /dev/null +++ b/examples/radix_sort/sort.py @@ -0,0 +1,359 @@ +from __future__ import annotations + +from contextlib import ExitStack +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional, Tuple + +import numpy as np + +import TensorFrost as tf + +__all__ = ["HistogramRadixSort", "radix_sort"] + +_TYPE_CODES: Dict[str, np.uint32] = { + "uint": np.uint32(0), + "int": np.uint32(1), + "float": np.uint32(2), +} + + +def _prepare_keys(keys: np.ndarray) -> Tuple[np.ndarray, np.dtype, str]: + array = np.asarray(keys) + if array.ndim != 1: + raise ValueError("radix_sort expects a 1D array of keys") + + dtype = array.dtype + if dtype == np.uint32: + return array, dtype, "uint" + + if dtype == np.int32: + return array, dtype, "int" + + if dtype == np.float32: + return array, dtype, "float" + + raise TypeError(f"Unsupported key dtype {dtype}; expected uint32, int32, or float32") + + +def _prepare_values(values: np.ndarray) -> Tuple[np.ndarray, np.dtype]: + array = np.asarray(values) + if array.ndim != 1: + raise ValueError("radix_sort expects a 1D array of values when provided") + + dtype = array.dtype + if dtype not in (np.uint32, np.int32, np.float32): + raise TypeError(f"Unsupported value dtype {dtype}; expected uint32, int32, or float32") + + return array, dtype + + +_SHADER_DIR = Path(__file__).resolve().parent / "shaders" + + +def _load_shader_source(filename: str) -> str: + shader_path = _SHADER_DIR / filename + return 
shader_path.read_text(encoding="utf-8") + + +@dataclass(frozen=True) +class _SorterKey: + bits_per_pass: int + block_size: int + group_size: int + + +class HistogramRadixSort: + """GPU histogram radix sort implemented with Slang + Vulkan.""" + + def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: int = 128) -> None: + if bits_per_pass <= 0: + raise ValueError("bits_per_pass must be positive") + if bits_per_pass > 8: + raise ValueError("bits_per_pass must be <= 8 to fit within MAX_HIST_SIZE") + if group_size != 128: + raise ValueError("This implementation currently requires group_size == 128") + if block_size <= 0 or block_size > 1024: + raise ValueError("block_size must be within (0, 1024]") + + self.bits_per_pass = bits_per_pass + self.block_size = block_size + self.group_size = group_size + + self._map_to_uint_program = tf.createComputeProgramFromSlang( + "radix_map_to_uint", + _load_shader_source("map_to_uint.slang"), + "csMapToUint", + ro_count=2, + rw_count=1, + ) + self._map_from_uint_program = tf.createComputeProgramFromSlang( + "radix_map_from_uint", + _load_shader_source("map_from_uint.slang"), + "csMapFromUint", + ro_count=2, + rw_count=1, + ) + + self._histogram_program = tf.createComputeProgramFromSlang( + "radix_histogram", + _load_shader_source("histogram.slang"), + "csHistogram", + ro_count=2, + rw_count=1, + ) + self._unpack_program = tf.createComputeProgramFromSlang( + "radix_unpack", + _load_shader_source("unpack.slang"), + "csUnpack", + ro_count=2, + rw_count=1, + ) + self._prefix_local_program = tf.createComputeProgramFromSlang( + "radix_prefix_local", + _load_shader_source("prefix_local.slang"), + "csPrefixLocal", + ro_count=2, + rw_count=2, + ) + self._prefix_blocks_program = tf.createComputeProgramFromSlang( + "radix_prefix_blocks", + _load_shader_source("prefix_block.slang"), + "csPrefixBlocks", + ro_count=2, + rw_count=1, + ) + self._prefix_accum_program = tf.createComputeProgramFromSlang( + "radix_prefix_accum", 
+ _load_shader_source("prefix_accum.slang"), + "csPrefixAccumulate", + ro_count=2, + rw_count=1, + ) + self._bucket_scan_program = tf.createComputeProgramFromSlang( + "radix_bucket_scan", + _load_shader_source("bucket_scan.slang"), + "csBucketScan", + ro_count=2, + rw_count=1, + ) + self._scatter_program = tf.createComputeProgramFromSlang( + "radix_scatter", + _load_shader_source("scatter.slang"), + "csScatter", + ro_count=5, + rw_count=2, + ) + + self._dummy_values_buffer = tf.createBuffer(1, 4, False) + + def close(self) -> None: + for program in ( + self._map_to_uint_program, + self._map_from_uint_program, + self._histogram_program, + self._unpack_program, + self._prefix_local_program, + self._prefix_blocks_program, + self._prefix_accum_program, + self._bucket_scan_program, + self._scatter_program, + ): + if program is not None: + program.release() + if self._dummy_values_buffer is not None: + self._dummy_values_buffer.release() + + def sort( + self, + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + max_bits: int = 32, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + keys_array, key_dtype, key_kind = _prepare_keys(keys) + element_count = int(keys_array.shape[0]) + + if values is not None: + values_array, values_dtype = _prepare_values(values) + if values_array.shape[0] != element_count: + raise ValueError("values must have the same length as keys") + else: + values_array = None + values_dtype = None + + if element_count == 0: + empty_keys = keys_array.copy() + if values_array is None: + return empty_keys, None + return empty_keys, values_array.copy() + + max_bits = int(min(max_bits, 32)) + histogram_size = 1 << self.bits_per_pass + mask = np.uint32(histogram_size - 1) + + num_groups = max((element_count + self.group_size - 1) // self.group_size, 1) + block_count = max((num_groups + self.block_size - 1) // self.block_size, 1) + packed_count = (histogram_size + 3) // 4 + passes = max((max_bits + self.bits_per_pass - 1) // self.bits_per_pass, 1) + 
+ params_array = np.zeros(8, dtype=np.uint32) + params_array[0] = np.uint32(element_count) + params_array[1] = np.uint32(histogram_size) + params_array[3] = mask + params_array[4] = np.uint32(num_groups) + params_array[5] = np.uint32(self.block_size) + params_array[6] = np.uint32(block_count) + params_array[7] = np.uint32(1 if values_array is not None else 0) + + map_params = np.zeros(4, dtype=np.uint32) + map_params[0] = np.uint32(element_count) + map_params[1] = _TYPE_CODES[key_kind] + + with ExitStack() as stack: + params_buffer = tf.createBuffer(params_array.size, 4, True) + stack.callback(params_buffer.release) + params_buffer.setData(params_array) + + map_buffer = tf.createBuffer(map_params.size, 4, True) + stack.callback(map_buffer.release) + map_buffer.setData(map_params) + + key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + for buf in key_buffers: + stack.callback(buf.release) + key_buffers[0].setData(keys_array) + + if values_array is not None: + value_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + for buf in value_buffers: + stack.callback(buf.release) + value_buffers[0].setData(values_array) + else: + dummy = self._dummy_values_buffer + value_buffers = [dummy, dummy] + + packed_hist_buffer = tf.createBuffer(max(packed_count * num_groups, 1), 4, False) + group_hist_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + prefix_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + block_totals_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + block_prefix_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + bucket_scan_buffer = tf.createBuffer(max(histogram_size, 1), 4, False) + + for buf in ( + packed_hist_buffer, + group_hist_buffer, + prefix_buffer, + block_totals_buffer, + block_prefix_buffer, + bucket_scan_buffer, + ): + stack.callback(buf.release) + + self._map_to_uint_program.run( + [map_buffer, 
key_buffers[0]], + [key_buffers[1]], + element_count, + ) + + key_in = key_buffers[1] + key_out = key_buffers[0] + val_in, val_out = value_buffers + + for pass_index in range(passes): + params_array[2] = np.uint32(pass_index * self.bits_per_pass) + params_buffer.setData(params_array) + + dispatch_threads = num_groups * self.group_size + self._histogram_program.run( + [params_buffer, key_in], + [packed_hist_buffer], + dispatch_threads, + ) + + self._unpack_program.run( + [params_buffer, packed_hist_buffer], + [group_hist_buffer], + histogram_size * num_groups, + ) + + self._prefix_local_program.run( + [params_buffer, group_hist_buffer], + [prefix_buffer, block_totals_buffer], + histogram_size * block_count, + ) + + self._prefix_blocks_program.run( + [params_buffer, block_totals_buffer], + [block_prefix_buffer], + histogram_size, + ) + + self._prefix_accum_program.run( + [params_buffer, block_prefix_buffer], + [prefix_buffer], + histogram_size * block_count, + ) + + self._bucket_scan_program.run( + [params_buffer, prefix_buffer], + [bucket_scan_buffer], + histogram_size, + ) + + self._scatter_program.run( + [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], + [key_out, val_out], + dispatch_threads, + ) + + key_in, key_out = key_out, key_in + if values_array is not None: + val_in, val_out = val_out, val_in + + self._map_from_uint_program.run( + [map_buffer, key_in], + [key_out], + element_count, + ) + + sorted_keys = key_out.getData(key_dtype, element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_in.getData(values_dtype, element_count) + else: + sorted_values = None + + return sorted_keys, sorted_values + + +_SORTER_CACHE: Dict[_SorterKey, HistogramRadixSort] = {} + + +def _get_sorter(bits_per_pass: int, block_size: int, group_size: int) -> HistogramRadixSort: + key = _SorterKey(bits_per_pass, block_size, group_size) + sorter = _SORTER_CACHE.get(key) + if sorter is None: + sorter = 
HistogramRadixSort(bits_per_pass=bits_per_pass, block_size=block_size, group_size=group_size) + _SORTER_CACHE[key] = sorter + return sorter + + +def radix_sort( + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + bits_per_pass: int = 6, + max_bits: int = 32, + block_size: int = 64, + group_size: int = 128, +): + """Run the GPU histogram radix sort on the provided keys (and optional values). + + Returns the sorted keys, and when ``values`` is provided also returns the permuted values. + """ + + sorter = _get_sorter(bits_per_pass, block_size, group_size) + sorted_keys, sorted_values = sorter.sort(keys, values, max_bits=max_bits) + if values is None: + return sorted_keys + return sorted_keys, sorted_values \ No newline at end of file From 5ab306352467865012aa53cf6792371abbd1f7f0 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Fri, 7 Nov 2025 05:09:35 +0100 Subject: [PATCH 31/44] Add renderdoc integration --- README.md | 3 +- TensorFrost/Backend/CMakeLists.txt | 1 + .../Backend/include/Backend/RenderDoc.h | 10 ++ TensorFrost/Backend/src/RenderDoc.cpp | 123 ++++++++++++++++++ TensorFrost/Backend/src/Vulkan.cpp | 16 +++ TensorFrost/PybindModule.cpp | 13 ++ examples/radix_sort/__main__.py | 13 ++ tests/renderdoc_test.py | 22 ++++ 8 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 TensorFrost/Backend/include/Backend/RenderDoc.h create mode 100644 TensorFrost/Backend/src/RenderDoc.cpp create mode 100644 tests/renderdoc_test.py diff --git a/README.md b/README.md index 4f0798c9..e7b6787e 100644 --- a/README.md +++ b/README.md @@ -689,11 +689,12 @@ You can also specify the clipping type for the gradients, by default the value o For debugging convenience there are 2 function types that you can call inside a tensor program: ```python +tf.renderdoc_is_available() tf.renderdoc_start_capture() tf.renderdoc_end_capture() ``` -These functions will start and end a RenderDoc capture, only if python is 
started from the RenderDoc GUI. This is useful for debugging the OpenGL backend, as it allows you to inspect compiled kernel execution, its code and buffers. +These functions will start and end a RenderDoc capture, only if python is started from the RenderDoc GUI. Call `tf.renderdoc_is_available()` first to check whether RenderDoc is attached so you can skip capture logic when it isn't. This is useful for debugging the OpenGL backend, as it allows you to inspect compiled kernel execution, its code and buffers. ```python tf.region_begin('Region name') diff --git a/TensorFrost/Backend/CMakeLists.txt b/TensorFrost/Backend/CMakeLists.txt index 4df5c05e..58a93ef9 100644 --- a/TensorFrost/Backend/CMakeLists.txt +++ b/TensorFrost/Backend/CMakeLists.txt @@ -16,6 +16,7 @@ target_include_directories(TensorFrostBackend PUBLIC ${TF_BACKEND_INC_DIR} $ENV{VULKAN_SDK}/Include + ${CMAKE_SOURCE_DIR}/external/renderdoc ${CMAKE_SOURCE_DIR}/external/imgui ${CMAKE_SOURCE_DIR}/external/imgui/backends) diff --git a/TensorFrost/Backend/include/Backend/RenderDoc.h b/TensorFrost/Backend/include/Backend/RenderDoc.h new file mode 100644 index 00000000..c3103c24 --- /dev/null +++ b/TensorFrost/Backend/include/Backend/RenderDoc.h @@ -0,0 +1,10 @@ +#pragma once + +namespace TensorFrost { + +void StartRenderDocCapture(); +void EndRenderDocCapture(); + +bool IsRenderDocAvailable(); + +} // namespace TensorFrost diff --git a/TensorFrost/Backend/src/RenderDoc.cpp b/TensorFrost/Backend/src/RenderDoc.cpp new file mode 100644 index 00000000..c676b5e9 --- /dev/null +++ b/TensorFrost/Backend/src/RenderDoc.cpp @@ -0,0 +1,123 @@ +#include "Backend/RenderDoc.h" + +#include + +#include +#include +#include +#include +#include + +#include "Backend/Vulkan.h" + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#include +#else +#include +#endif + +namespace TensorFrost { +namespace { +RENDERDOC_API_1_4_2* gRenderDocApi = nullptr; + +void LoadRenderDoc() +{ + static bool loggedUnavailable = false; + + if 
(gRenderDocApi) { + return; + } + +#if defined(_WIN32) + const char* moduleNames[] = {"renderdoc.dll", "renderdoccmd.dll"}; + + for (const char* moduleName : moduleNames) { + HMODULE mod = GetModuleHandleA(moduleName); + if (!mod) { + continue; + } + + auto getApi = reinterpret_cast(GetProcAddress(mod, "RENDERDOC_GetAPI")); + if (getApi && + getApi(eRENDERDOC_API_Version_1_4_2, reinterpret_cast(&gRenderDocApi)) == 1) { + std::cout << "RenderDoc API loaded from " << moduleName << std::endl; + loggedUnavailable = false; + break; + } + } +#else + if (void* mod = dlopen("librenderdoc.so", RTLD_NOW | RTLD_NOLOAD)) { + auto getApi = reinterpret_cast(dlsym(mod, "RENDERDOC_GetAPI")); + if (getApi && getApi(eRENDERDOC_API_Version_1_4_2, reinterpret_cast(&gRenderDocApi)) == 1) { + std::cout << "RenderDoc API loaded" << std::endl; + loggedUnavailable = false; + dlclose(mod); + return; + } + dlclose(mod); + } +#endif + + if (!gRenderDocApi && !loggedUnavailable) { + std::cout << "RenderDoc API not available" << std::endl; + loggedUnavailable = true; + } +} +} // namespace + +bool IsRenderDocAvailable() +{ + LoadRenderDoc(); + return gRenderDocApi != nullptr; +} + +void StartRenderDocCapture() +{ + if (!IsRenderDocAvailable()) { + std::cout << "RenderDoc not available; start capture skipped" << std::endl; + return; + } + + RENDERDOC_DevicePointer deviceHandle = nullptr; + try { + auto& ctx = getVulkanContext(); + VkDevice vkDevice = ctx.device; + deviceHandle = reinterpret_cast(vkDevice); + } catch (const std::exception& e) { + std::cout << "RenderDoc capture start warning: failed to get Vulkan device (" << e.what() << ")" << std::endl; + } + + gRenderDocApi->StartFrameCapture(deviceHandle, nullptr); + if (gRenderDocApi->IsFrameCapturing && gRenderDocApi->IsFrameCapturing() == 1) { + std::cout << "RenderDoc capture start requested" << std::endl; + } else { + std::cout << "RenderDoc capture start did not begin (IsFrameCapturing=0)" << std::endl; + } +} + +void 
EndRenderDocCapture() +{ + if (!IsRenderDocAvailable()) { + std::cout << "RenderDoc not available; end capture skipped" << std::endl; + return; + } + + RENDERDOC_DevicePointer deviceHandle = nullptr; + try { + auto& ctx = getVulkanContext(); + VkDevice vkDevice = ctx.device; + deviceHandle = reinterpret_cast(vkDevice); + } catch (const std::exception& e) { + std::cout << "RenderDoc capture end warning: failed to get Vulkan device (" << e.what() << ")" << std::endl; + } + + const uint32_t result = gRenderDocApi->EndFrameCapture(deviceHandle, nullptr); + if (result == 1) { + std::cout << "RenderDoc capture end requested" << std::endl; + } else { + std::cout << "RenderDoc capture end failed (" << result << ")" << std::endl; + } +} + +} // namespace TensorFrost diff --git a/TensorFrost/Backend/src/Vulkan.cpp b/TensorFrost/Backend/src/Vulkan.cpp index be5a3c8c..76bfbabb 100644 --- a/TensorFrost/Backend/src/Vulkan.cpp +++ b/TensorFrost/Backend/src/Vulkan.cpp @@ -3,6 +3,7 @@ VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE #include #include #include +#include #include namespace { @@ -338,6 +339,10 @@ static std::vector compileGLSLToSpirv(const std::string& source) { shaderc::CompileOptions opts; opts.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_1); +#if defined(_RELWITHDEBINFO) + opts.SetGenerateDebugInfo(); + opts.SetOptimizationLevel(shaderc_optimization_level_zero); +#endif shaderc::SpvCompilationResult result = compiler.CompileGlslToSpv(source, shaderc_compute_shader, "shader", opts); if (result.GetCompilationStatus() != shaderc_compilation_status_success) { @@ -359,6 +364,17 @@ std::vector compileSlangToSpirv(const char* moduleName, slang::SessionDesc sd{}; sd.targets = &tgt; sd.targetCount = 1; +#if defined(_RELWITHDEBINFO) + std::array optionEntries{}; + optionEntries[0].name = slang::CompilerOptionName::DebugInformation; + optionEntries[0].value.kind = slang::CompilerOptionValueKind::Int; + optionEntries[0].value.intValue0 = 
SLANG_DEBUG_INFO_LEVEL_STANDARD; + optionEntries[1].name = slang::CompilerOptionName::Optimization; + optionEntries[1].value.kind = slang::CompilerOptionValueKind::Int; + optionEntries[1].value.intValue0 = SLANG_OPTIMIZATION_LEVEL_NONE; + sd.compilerOptionEntries = optionEntries.data(); + sd.compilerOptionEntryCount = static_cast(optionEntries.size()); +#endif Slang::ComPtr session; global->createSession(sd, session.writeRef()); diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 253d1d23..bc27ffe9 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -10,6 +10,7 @@ #include #include "TensorFrost.h" +#include "Backend/RenderDoc.h" #include "Definitions/VulkanBindings.h" namespace py = pybind11; @@ -175,6 +176,18 @@ PYBIND11_MODULE(TensorFrost, m) { py::print(program.DebugPrint()); VulkanDefinitions(m); + + m.def("renderdoc_is_available", []() { + return IsRenderDocAvailable(); + }, "Check if RenderDoc is attached to this process"); + + m.def("renderdoc_start_capture", []() { + StartRenderDocCapture(); + }, "Start a RenderDoc capture"); + + m.def("renderdoc_end_capture", []() { + EndRenderDocCapture(); + }, "End the current RenderDoc capture"); // VulkanContext ctx; // // const size_t N = 1024; diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py index 39ea0288..466e8a6e 100644 --- a/examples/radix_sort/__main__.py +++ b/examples/radix_sort/__main__.py @@ -44,6 +44,19 @@ def main() -> None: values = rng.integers(0, 1 << 31, size=count, dtype=np.uint32) with ExitStack() as stack: + renderdoc_is_available = getattr(tf, "renderdoc_is_available", None) + renderdoc_start = getattr(tf, "renderdoc_start_capture", None) + renderdoc_end = getattr(tf, "renderdoc_end_capture", None) + if ( + callable(renderdoc_is_available) + and renderdoc_is_available() + and callable(renderdoc_start) + and callable(renderdoc_end) + ): + renderdoc_start() + print("RenderDoc capture started") + 
stack.callback(renderdoc_end) + sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) stack.callback(sorter.close) start_time = time.perf_counter() diff --git a/tests/renderdoc_test.py b/tests/renderdoc_test.py new file mode 100644 index 00000000..187ef6d8 --- /dev/null +++ b/tests/renderdoc_test.py @@ -0,0 +1,22 @@ +import unittest + +import TensorFrost as tf + + +class RenderDocBindingTest(unittest.TestCase): + def test_renderdoc_functions_exist(self): + self.assertTrue(hasattr(tf, "renderdoc_start_capture")) + self.assertTrue(hasattr(tf, "renderdoc_end_capture")) + self.assertTrue(hasattr(tf, "renderdoc_is_available")) + + def test_renderdoc_capture_calls(self): + # Calls shouldn't raise even when RenderDoc isn't attached. + tf.renderdoc_start_capture() + tf.renderdoc_end_capture() + + def test_renderdoc_available_returns_bool(self): + self.assertIsInstance(tf.renderdoc_is_available(), bool) + + +if __name__ == "__main__": + unittest.main() From b1f5555f775c8c6f1ba54eefd1a403442a2a67d2 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Fri, 7 Nov 2025 18:28:08 +0100 Subject: [PATCH 32/44] Have a replay ui show option for renderdoc --- Python/TensorFrost/sort.py | 6 ++- .../Backend/include/Backend/RenderDoc.h | 5 +- TensorFrost/Backend/src/RenderDoc.cpp | 50 ++++++++++++++++--- TensorFrost/PybindModule.cpp | 10 ++-- examples/radix_sort/__main__.py | 16 ++++-- examples/radix_sort/shaders/scatter.slang | 21 +++----- examples/radix_sort/sort.py | 6 ++- tests/renderdoc_test.py | 5 +- 8 files changed, 83 insertions(+), 36 deletions(-) diff --git a/Python/TensorFrost/sort.py b/Python/TensorFrost/sort.py index 277d130e..3db1efee 100644 --- a/Python/TensorFrost/sort.py +++ b/Python/TensorFrost/sort.py @@ -79,6 +79,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: self.bits_per_pass = bits_per_pass self.block_size = block_size self.group_size = group_size + self.histogram_size = 1 << 
bits_per_pass self._map_to_uint_program = tf.createComputeProgramFromSlang( "radix_map_to_uint", @@ -137,9 +138,10 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: ro_count=2, rw_count=1, ) + scatter_source = f"#define TF_HISTOGRAM_SIZE {self.histogram_size}u\n" + _load_shader_source("scatter.slang") self._scatter_program = tf.createComputeProgramFromSlang( "radix_scatter", - _load_shader_source("scatter.slang"), + scatter_source, "csScatter", ro_count=5, rw_count=2, @@ -189,7 +191,7 @@ def sort( return empty_keys, values_array.copy() max_bits = int(min(max_bits, 32)) - histogram_size = 1 << self.bits_per_pass + histogram_size = self.histogram_size mask = np.uint32(histogram_size - 1) num_groups = max((element_count + self.group_size - 1) // self.group_size, 1) diff --git a/TensorFrost/Backend/include/Backend/RenderDoc.h b/TensorFrost/Backend/include/Backend/RenderDoc.h index c3103c24..a89e14df 100644 --- a/TensorFrost/Backend/include/Backend/RenderDoc.h +++ b/TensorFrost/Backend/include/Backend/RenderDoc.h @@ -1,10 +1,11 @@ #pragma once +#include + namespace TensorFrost { void StartRenderDocCapture(); -void EndRenderDocCapture(); - +std::string EndRenderDocCapture(bool launchReplayUI = false); bool IsRenderDocAvailable(); } // namespace TensorFrost diff --git a/TensorFrost/Backend/src/RenderDoc.cpp b/TensorFrost/Backend/src/RenderDoc.cpp index c676b5e9..0b5fbc7f 100644 --- a/TensorFrost/Backend/src/RenderDoc.cpp +++ b/TensorFrost/Backend/src/RenderDoc.cpp @@ -82,10 +82,10 @@ void StartRenderDocCapture() RENDERDOC_DevicePointer deviceHandle = nullptr; try { auto& ctx = getVulkanContext(); - VkDevice vkDevice = ctx.device; - deviceHandle = reinterpret_cast(vkDevice); + VkInstance vkInstance = static_cast(ctx.instance); + deviceHandle = RENDERDOC_DEVICEPOINTER_FROM_VKINSTANCE(vkInstance); } catch (const std::exception& e) { - std::cout << "RenderDoc capture start warning: failed to get Vulkan device (" << e.what() << ")" << 
std::endl; + std::cout << "RenderDoc capture start warning: failed to get Vulkan instance (" << e.what() << ")" << std::endl; } gRenderDocApi->StartFrameCapture(deviceHandle, nullptr); @@ -96,28 +96,62 @@ void StartRenderDocCapture() } } -void EndRenderDocCapture() +std::string EndRenderDocCapture(bool launchReplayUI) { if (!IsRenderDocAvailable()) { std::cout << "RenderDoc not available; end capture skipped" << std::endl; - return; + return {}; } RENDERDOC_DevicePointer deviceHandle = nullptr; try { auto& ctx = getVulkanContext(); - VkDevice vkDevice = ctx.device; - deviceHandle = reinterpret_cast(vkDevice); + VkInstance vkInstance = static_cast(ctx.instance); + deviceHandle = RENDERDOC_DEVICEPOINTER_FROM_VKINSTANCE(vkInstance); } catch (const std::exception& e) { - std::cout << "RenderDoc capture end warning: failed to get Vulkan device (" << e.what() << ")" << std::endl; + std::cout << "RenderDoc capture end warning: failed to get Vulkan instance (" << e.what() << ")" << std::endl; } + std::string capturePath; const uint32_t result = gRenderDocApi->EndFrameCapture(deviceHandle, nullptr); if (result == 1) { std::cout << "RenderDoc capture end requested" << std::endl; + if (gRenderDocApi->GetNumCaptures && gRenderDocApi->GetCapture) { + const uint32_t numCaptures = gRenderDocApi->GetNumCaptures(); + if (numCaptures > 0) { + const uint32_t captureIndex = numCaptures - 1; + uint32_t pathLength = 0; + if (gRenderDocApi->GetCapture(captureIndex, nullptr, &pathLength, nullptr) == 1 && pathLength > 0) { + capturePath.resize(pathLength); + if (gRenderDocApi->GetCapture(captureIndex, capturePath.data(), &pathLength, nullptr) == 1) { + if (!capturePath.empty() && capturePath.back() == '\0') { + capturePath.pop_back(); + } + std::cout << "RenderDoc capture stored at: " << capturePath << std::endl; + if (launchReplayUI && !capturePath.empty()) { + if (gRenderDocApi->LaunchReplayUI) { + const std::string commandLine = "\"" + capturePath + "\""; + const uint32_t pid = 
gRenderDocApi->LaunchReplayUI(0, commandLine.c_str()); + if (pid == 0) { + std::cout << "RenderDoc replay UI launch failed" << std::endl; + } else { + std::cout << "RenderDoc replay UI launched (PID " << pid << ")" << std::endl; + } + } else { + std::cout << "RenderDoc replay UI launch unavailable (API function missing)" << std::endl; + } + } + } else { + capturePath.clear(); + } + } + } + } } else { std::cout << "RenderDoc capture end failed (" << result << ")" << std::endl; } + + return capturePath; } } // namespace TensorFrost diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index bc27ffe9..80e7c129 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -185,9 +185,13 @@ PYBIND11_MODULE(TensorFrost, m) { StartRenderDocCapture(); }, "Start a RenderDoc capture"); - m.def("renderdoc_end_capture", []() { - EndRenderDocCapture(); - }, "End the current RenderDoc capture"); + m.def("renderdoc_end_capture", + [](bool launch_replay_ui) { + return EndRenderDocCapture(launch_replay_ui); + }, + py::arg("launch_replay_ui") = false, + "End the current RenderDoc capture and optionally launch the RenderDoc replay UI. 
" + "Returns the capture path if one was produced, otherwise an empty string."); // VulkanContext ctx; // // const size_t N = 1024; diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py index 466e8a6e..ce50de74 100644 --- a/examples/radix_sort/__main__.py +++ b/examples/radix_sort/__main__.py @@ -47,6 +47,7 @@ def main() -> None: renderdoc_is_available = getattr(tf, "renderdoc_is_available", None) renderdoc_start = getattr(tf, "renderdoc_start_capture", None) renderdoc_end = getattr(tf, "renderdoc_end_capture", None) + renderdoc_finalize = None if ( callable(renderdoc_is_available) and renderdoc_is_available() @@ -54,14 +55,21 @@ def main() -> None: and callable(renderdoc_end) ): renderdoc_start() + renderdoc_finalize = renderdoc_end print("RenderDoc capture started") - stack.callback(renderdoc_end) sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) stack.callback(sorter.close) - start_time = time.perf_counter() - sorted_keys, sorted_values = sorter.sort(keys, values) - elapsed = time.perf_counter() - start_time + try: + start_time = time.perf_counter() + sorted_keys, sorted_values = sorter.sort(keys, values) + elapsed = time.perf_counter() - start_time + finally: + if callable(renderdoc_finalize): + capture_path = renderdoc_finalize(launch_replay_ui=True) + if capture_path: + print(f"RenderDoc capture saved to: {capture_path}") + renderdoc_finalize = None if sorted_values is None: sorted_values = np.empty_like(values) diff --git a/examples/radix_sort/shaders/scatter.slang b/examples/radix_sort/shaders/scatter.slang index 3ff2afe3..ddf2f3ac 100644 --- a/examples/radix_sort/shaders/scatter.slang +++ b/examples/radix_sort/shaders/scatter.slang @@ -1,11 +1,9 @@ static const uint GROUP_SIZE = 128; static const uint QUARTER_SIZE = GROUP_SIZE / 4; -static const uint MAX_HIST_SIZE = 256; - -uint packedCount(uint histogramSize) -{ - return (histogramSize + 3u) >> 2; -} +#ifndef TF_HISTOGRAM_SIZE +#define TF_HISTOGRAM_SIZE 256 +#endif +static 
const uint HISTOGRAM_SIZE = TF_HISTOGRAM_SIZE; [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); [[vk::binding(1,0)]] StructuredBuffer KeysIn : register(t1, space0); @@ -16,32 +14,27 @@ uint packedCount(uint histogramSize) [[vk::binding(6,0)]] RWStructuredBuffer ValuesOut : register(u6, space0); groupshared uint tempBits[GROUP_SIZE]; -groupshared uint halfCount[MAX_HIST_SIZE]; +groupshared uint halfCount[HISTOGRAM_SIZE]; [numthreads(GROUP_SIZE, 1, 1)] void csScatter(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) { uint elementCount = Params[0]; - uint histogramSize = Params[1]; uint shift = Params[2]; uint mask = Params[3]; uint numGroups = Params[4]; uint hasValues = Params[7]; - if (histogramSize > MAX_HIST_SIZE) - return; - uint group = groupID.x; if (group >= numGroups) return; uint lane = localID.x; - for (uint idx = lane; idx < histogramSize; idx += GROUP_SIZE) + for (uint idx = lane; idx < HISTOGRAM_SIZE; idx += GROUP_SIZE) { halfCount[idx] = 0; } - GroupMemoryBarrierWithGroupSync(); uint globalIndex = group * GROUP_SIZE + lane; bool active = (globalIndex < elementCount); @@ -78,7 +71,7 @@ void csScatter(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) return; uint prevBucket = (bit == 0) ? 0 : BucketScan[bit - 1u]; - uint prevGroup = (group == 0) ? 0 : GroupPrefix[(group - 1u) * histogramSize + bit]; + uint prevGroup = (group == 0) ? 
0 : GroupPrefix[(group - 1u) * HISTOGRAM_SIZE + bit]; uint quarterOffset = 0; if (quarterIndex > 0) diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py index 68569fb7..df9535d6 100644 --- a/examples/radix_sort/sort.py +++ b/examples/radix_sort/sort.py @@ -79,6 +79,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: self.bits_per_pass = bits_per_pass self.block_size = block_size self.group_size = group_size + self.histogram_size = 1 << bits_per_pass self._map_to_uint_program = tf.createComputeProgramFromSlang( "radix_map_to_uint", @@ -137,9 +138,10 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: ro_count=2, rw_count=1, ) + scatter_source = f"#define TF_HISTOGRAM_SIZE {self.histogram_size}u\n" + _load_shader_source("scatter.slang") self._scatter_program = tf.createComputeProgramFromSlang( "radix_scatter", - _load_shader_source("scatter.slang"), + scatter_source, "csScatter", ro_count=5, rw_count=2, @@ -189,7 +191,7 @@ def sort( return empty_keys, values_array.copy() max_bits = int(min(max_bits, 32)) - histogram_size = 1 << self.bits_per_pass + histogram_size = self.histogram_size mask = np.uint32(histogram_size - 1) num_groups = max((element_count + self.group_size - 1) // self.group_size, 1) diff --git a/tests/renderdoc_test.py b/tests/renderdoc_test.py index 187ef6d8..8870ee51 100644 --- a/tests/renderdoc_test.py +++ b/tests/renderdoc_test.py @@ -12,7 +12,10 @@ def test_renderdoc_functions_exist(self): def test_renderdoc_capture_calls(self): # Calls shouldn't raise even when RenderDoc isn't attached. 
tf.renderdoc_start_capture() - tf.renderdoc_end_capture() + path = tf.renderdoc_end_capture() + self.assertIsInstance(path, str) + path = tf.renderdoc_end_capture(launch_replay_ui=False) + self.assertIsInstance(path, str) def test_renderdoc_available_returns_bool(self): self.assertIsInstance(tf.renderdoc_is_available(), bool) From 2c773fb2ed5335c54ded34131458ff710e3d738a Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Fri, 7 Nov 2025 21:02:08 +0100 Subject: [PATCH 33/44] Add test --- AGENTS.md | 9 +++ tests/slang_compile_test.py | 112 ++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 AGENTS.md create mode 100644 tests/slang_compile_test.py diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..ce9710e9 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,9 @@ +# Agent Guide + +Follow these expectations whenever you work in this repository: + +1. **Full rebuild & virtual environment** — Run `setup_python_env.cmd` from the repo root. It configures the Python virtual environment and performs a clean rebuild so you start from a consistent state. +2. **Partial rebuilds** — Use CMake for incremental builds. Invoke the appropriate CMake build command (for example, `cmake --build --target `) to rebuild only what you need. +3. **C++ changes** — Any edits under `TensorFrost/` or other C++ sources require a rebuild before the changes take effect. +4. **API validation** — After modifying functionality, run the relevant tests in the `tests/` folder to confirm the Python API still behaves as expected. +5. **Scenario validation** — Run the sample programs in the `examples/` folder to make sure the updated stack handles more complex end-to-end flows. 
diff --git a/tests/slang_compile_test.py b/tests/slang_compile_test.py new file mode 100644 index 00000000..05288e6f --- /dev/null +++ b/tests/slang_compile_test.py @@ -0,0 +1,112 @@ +import unittest +from contextlib import ExitStack +from pathlib import Path + +import numpy as np + +import TensorFrost as tf + + +_SIMPLE_SLANG_SHADER = """[[vk::binding(0,0)]] StructuredBuffer InputBuffer : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer OutputBuffer : register(u1, space0); + +[numthreads(64, 1, 1)] +void csMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + if (dispatchThreadID.x != 0u) + return; + + uint value = InputBuffer[0]; + OutputBuffer[0] = value + 1u; +} +""" + +_REQUIRED_ARTIFACTS = ( + "slang.dll", + "slang-glslang.dll", + "spirv-opt.exe", +) + + +def _runtime_dir() -> Path: + return Path(tf.__file__).resolve().parent + + +def _missing_runtime_artifacts() -> list[str]: + runtime_dir = _runtime_dir() + missing: list[str] = [] + for artifact in _REQUIRED_ARTIFACTS: + release_candidate = runtime_dir / artifact + debug_candidate = runtime_dir / release_candidate.with_name(release_candidate.stem + "d" + release_candidate.suffix).name + if not release_candidate.exists() and not debug_candidate.exists(): + missing.append(artifact) + return missing + +def _should_skip_for_backend(exc: Exception) -> bool: + message = str(exc).lower() + keywords = ( + "glfw", + "vulkan", + "no physical devices", + "no suitable", + "device", + "surface", + "swapchain", + ) + return any(token in message for token in keywords) + + +class SlangCompilationTest(unittest.TestCase): + def test_compile_and_execute_simple_shader(self) -> None: + invocation_count = 1 + + missing_artifacts = _missing_runtime_artifacts() + if missing_artifacts: # pragma: no cover - environment not staged + pretty = ", ".join(missing_artifacts) + self.skipTest( + "Slang runtime components missing: " + f"{pretty}. 
Re-run setup_python_env.cmd or rebuild the TensorFrost target to stage runtimes." + ) + + try: + readonly_buffer = tf.createBuffer(invocation_count, 4, True) + except RuntimeError as exc: # pragma: no cover - Vulkan not available + self.skipTest(f"Vulkan buffer creation failed: {exc}") + + with ExitStack() as resources: + resources.callback(readonly_buffer.release) + + try: + readwrite_buffer = tf.createBuffer(invocation_count, 4, False) + except RuntimeError as exc: # pragma: no cover - Vulkan not available + self.skipTest(f"Vulkan buffer creation failed: {exc}") + + resources.callback(readwrite_buffer.release) + + try: + program = tf.createComputeProgramFromSlang( + "tensorfrost_test_shader", + _SIMPLE_SLANG_SHADER, + "csMain", + ro_count=1, + rw_count=1, + ) + except RuntimeError as exc: + if _should_skip_for_backend(exc): # pragma: no cover - Vulkan not available + self.skipTest(f"Slang compilation backend unavailable: {exc}") + raise + + resources.callback(program.release) + + readonly_buffer.setData(np.array([7], dtype=np.uint32)) + readwrite_buffer.setData(np.zeros(1, dtype=np.uint32)) + + program.run([readonly_buffer], [readwrite_buffer], invocation_count) + + result = readwrite_buffer.getData(np.dtype(np.uint32), invocation_count) + self.assertEqual(result.shape, (invocation_count,)) + self.assertEqual(int(result[0]), 8) + + +if __name__ == "__main__": + unittest.main() From 155bf28544d29d70479580977fe139e4b1db28dc Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sat, 8 Nov 2025 00:19:30 +0100 Subject: [PATCH 34/44] Properly specify group count --- Python/TensorFrost/sort.py | 35 ++++++++---- TensorFrost/Backend/include/Backend/Vulkan.h | 5 +- TensorFrost/Backend/src/Vulkan.cpp | 5 +- .../src/Definitions/VulkanBindings.cpp | 4 +- .../src/Definitions/VulkanInterface.cpp | 4 +- TensorFrost/src/Definitions/VulkanInterface.h | 2 +- examples/Slang/mandelbrot.py | 5 +- examples/debug.py | 4 +- 
examples/radix_sort/shaders/histogram.slang | 4 +- .../radix_sort/shaders/map_from_uint.slang | 7 ++- examples/radix_sort/shaders/map_to_uint.slang | 3 +- examples/radix_sort/shaders/scatter.slang | 5 +- examples/radix_sort/sort.py | 54 ++++++++++++++----- tests/slang_compile_test.py | 14 ++--- tests/vulkan_window_test.py | 10 ++-- 15 files changed, 108 insertions(+), 53 deletions(-) diff --git a/Python/TensorFrost/sort.py b/Python/TensorFrost/sort.py index 3db1efee..60845e5c 100644 --- a/Python/TensorFrost/sort.py +++ b/Python/TensorFrost/sort.py @@ -18,6 +18,12 @@ } +def _dispatch_groups(work_items: int, threads_per_group: int) -> int: + if work_items <= 0: + return 0 + return (work_items + threads_per_group - 1) // threads_per_group + + def _prepare_keys(keys: np.ndarray) -> Tuple[np.ndarray, np.dtype, str]: array = np.asarray(keys) if array.ndim != 1: @@ -252,10 +258,20 @@ def sort( ): stack.callback(buf.release) + map_groups = _dispatch_groups(element_count, self.group_size) + reduction_group_size = 64 + unpack_groups = _dispatch_groups(histogram_size * num_groups, reduction_group_size) + prefix_local_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + prefix_block_groups = _dispatch_groups(histogram_size, reduction_group_size) + prefix_accum_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + bucket_scan_groups = _dispatch_groups(histogram_size, reduction_group_size) + scatter_groups = num_groups + histogram_groups = num_groups + self._map_to_uint_program.run( [map_buffer, key_buffers[0]], [key_buffers[1]], - element_count, + map_groups, ) key_in = key_buffers[1] @@ -266,47 +282,46 @@ def sort( params_array[2] = np.uint32(pass_index * self.bits_per_pass) params_buffer.setData(params_array) - dispatch_threads = num_groups * self.group_size self._histogram_program.run( [params_buffer, key_in], [packed_hist_buffer], - dispatch_threads, + histogram_groups, ) self._unpack_program.run( [params_buffer, 
packed_hist_buffer], [group_hist_buffer], - histogram_size * num_groups, + unpack_groups, ) self._prefix_local_program.run( [params_buffer, group_hist_buffer], [prefix_buffer, block_totals_buffer], - histogram_size * block_count, + prefix_local_groups, ) self._prefix_blocks_program.run( [params_buffer, block_totals_buffer], [block_prefix_buffer], - histogram_size, + prefix_block_groups, ) self._prefix_accum_program.run( [params_buffer, block_prefix_buffer], [prefix_buffer], - histogram_size * block_count, + prefix_accum_groups, ) self._bucket_scan_program.run( [params_buffer, prefix_buffer], [bucket_scan_buffer], - histogram_size, + bucket_scan_groups, ) self._scatter_program.run( [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], [key_out, val_out], - dispatch_threads, + scatter_groups, ) key_in, key_out = key_out, key_in @@ -316,7 +331,7 @@ def sort( self._map_from_uint_program.run( [map_buffer, key_in], [key_out], - element_count, + map_groups, ) sorted_keys = key_out.getData(key_dtype, element_count) diff --git a/TensorFrost/Backend/include/Backend/Vulkan.h b/TensorFrost/Backend/include/Backend/Vulkan.h index 001c30f7..fc8e2715 100644 --- a/TensorFrost/Backend/include/Backend/Vulkan.h +++ b/TensorFrost/Backend/include/Backend/Vulkan.h @@ -93,6 +93,9 @@ ComputeProgram createComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount); void destroyComputeProgram(ComputeProgram& prog); -void runProgram(const ComputeProgram& prog, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers, uint32_t n); +void runProgram(const ComputeProgram& prog, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers, + uint32_t groupCount); VulkanContext& getVulkanContext(); \ No newline at end of file diff --git a/TensorFrost/Backend/src/Vulkan.cpp b/TensorFrost/Backend/src/Vulkan.cpp index 76bfbabb..755220af 100644 --- 
a/TensorFrost/Backend/src/Vulkan.cpp +++ b/TensorFrost/Backend/src/Vulkan.cpp @@ -494,7 +494,7 @@ void destroyComputeProgram(ComputeProgram& prog) { void runProgram(const ComputeProgram& prog, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers, - uint32_t n) { + uint32_t groupCount) { auto& ctx = getVulkanContext(); auto set = getOrCreateSet(ctx, prog, readonlyBuffers, readwriteBuffers); @@ -504,8 +504,7 @@ void runProgram(const ComputeProgram& prog, cmd.begin(vk::CommandBufferBeginInfo{}); cmd.bindPipeline(vk::PipelineBindPoint::eCompute, prog.pipeline); cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, prog.pipelineLayout, 0, set, {}); - uint32_t gs = 64, groups = (n + gs - 1) / gs; - cmd.dispatch(groups, 1, 1); + cmd.dispatch(groupCount, 1, 1); cmd.end(); vk::Fence fence = ctx.device.createFence({}); diff --git a/TensorFrost/src/Definitions/VulkanBindings.cpp b/TensorFrost/src/Definitions/VulkanBindings.cpp index 28c682a3..f2be53f2 100644 --- a/TensorFrost/src/Definitions/VulkanBindings.cpp +++ b/TensorFrost/src/Definitions/VulkanBindings.cpp @@ -47,8 +47,8 @@ void VulkanDefinitions(py::module_& m) { .def_property_readonly("readwrite_count", &PyComputeProgram::readwriteCount, "Number of read-write storage buffers expected by the program.") .def("run", &PyComputeProgram::run, - py::arg("readonly_buffers"), py::arg("readwrite_buffers"), py::arg("num_invocations"), - "Dispatch the compute pipeline with the provided buffers and invocation count.") + py::arg("readonly_buffers"), py::arg("readwrite_buffers"), py::arg("group_count"), + "Dispatch the compute pipeline with the provided buffers and workgroup count.") .def("release", &PyComputeProgram::release, "Explicitly destroy the underlying Vulkan pipeline and associated resources."); diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp index 40edb4d7..1dfe20cc 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.cpp +++ 
b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -174,7 +174,7 @@ PyComputeProgram& PyComputeProgram::operator=(PyComputeProgram&& other) noexcept void PyComputeProgram::run(const py::iterable& readonlyBuffers, const py::iterable& readwriteBuffers, - uint32_t numInvocations) { + uint32_t groupCount) { ensureValid(); std::vector ro; std::vector rw; @@ -184,7 +184,7 @@ void PyComputeProgram::run(const py::iterable& readonlyBuffers, throw std::runtime_error("buffer count does not match program layout"); } py::gil_scoped_release release; - runProgram(program_, ro, rw, numInvocations); + runProgram(program_, ro, rw, groupCount); } void PyComputeProgram::release() { diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h index 9d299b06..4219205c 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.h +++ b/TensorFrost/src/Definitions/VulkanInterface.h @@ -72,7 +72,7 @@ class PyComputeProgram { void run(const pybind11::iterable& readonlyBuffers, const pybind11::iterable& readwriteBuffers, - uint32_t numInvocations); + uint32_t groupCount); void release(); diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py index 5575ce44..b26999bd 100644 --- a/examples/Slang/mandelbrot.py +++ b/examples/Slang/mandelbrot.py @@ -26,6 +26,7 @@ def main() -> None: shader_source = load_shader() program = tf.createComputeProgramFromSlang("mandelbrot", shader_source, "csMain", ro_count=1, rw_count=1) + local_size = 64 center = [-0.5, 0.0] scale = 3.0 @@ -61,6 +62,8 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: width, height = win.size width = max(1, int(width)) height = max(1, int(height)) + thread_count = max(1, width * height) + group_count = max((thread_count + local_size - 1) // local_size, 1) ensure_pixel_buffer(width, height) @@ -136,7 +139,7 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: params[7] = 1.0 if swap_rb else 0.0 params_buffer.setData(params) - 
program.run([params_buffer], [pixel_buffer], width * height) + program.run([params_buffer], [pixel_buffer], group_count) win.drawBuffer(pixel_buffer, width, height) diff --git a/examples/debug.py b/examples/debug.py index 0ad93ce9..5ffab1a5 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -52,6 +52,8 @@ def main(): W, H = 1024, 768 + local_size = 64 + group_count = max((W * H + local_size - 1) // local_size, 1) win = tf.createWindow(W, H, "Mandelbrot (compute → buffer → swapchain)") fmt = int(win.format) is_bgra = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB @@ -73,7 +75,7 @@ def main(): try: while win.isOpen(): - prog.run([params], [pix], W*H) + prog.run([params], [pix], group_count) win.drawBuffer(pix, W, H) finally: win.close() diff --git a/examples/radix_sort/shaders/histogram.slang b/examples/radix_sort/shaders/histogram.slang index 5a7592ed..01e50b47 100644 --- a/examples/radix_sort/shaders/histogram.slang +++ b/examples/radix_sort/shaders/histogram.slang @@ -1,6 +1,6 @@ -static const uint GROUP_SIZE = 128; +static const uint GROUP_SIZE = TF_GROUP_SIZE; static const uint QUARTER_SIZE = GROUP_SIZE / 4; -static const uint MAX_HIST_SIZE = 256; +static const uint MAX_HIST_SIZE = 256u; uint packedCount(uint histogramSize) { diff --git a/examples/radix_sort/shaders/map_from_uint.slang b/examples/radix_sort/shaders/map_from_uint.slang index 86f819f9..f9542ffe 100644 --- a/examples/radix_sort/shaders/map_from_uint.slang +++ b/examples/radix_sort/shaders/map_from_uint.slang @@ -4,11 +4,16 @@ static const uint TYPE_FLOAT = 2u; static const uint SIGN_BIT = 0x80000000u; static const uint FULL_MASK = 0xFFFFFFFFu; +#if !defined(TF_GROUP_SIZE) +#error "TF_GROUP_SIZE must be defined" +#endif +static const uint GROUP_SIZE = TF_GROUP_SIZE; + [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); [[vk::binding(1,0)]] StructuredBuffer Input : register(t1, space0); [[vk::binding(2,0)]] RWStructuredBuffer Output : register(u2, space0); 
-[numthreads(128, 1, 1)] +[numthreads(GROUP_SIZE, 1, 1)] void csMapFromUint(uint3 dispatchThreadID : SV_DispatchThreadID) { uint index = dispatchThreadID.x; diff --git a/examples/radix_sort/shaders/map_to_uint.slang b/examples/radix_sort/shaders/map_to_uint.slang index b6be57f5..6046235f 100644 --- a/examples/radix_sort/shaders/map_to_uint.slang +++ b/examples/radix_sort/shaders/map_to_uint.slang @@ -3,12 +3,13 @@ static const uint TYPE_INT = 1u; static const uint TYPE_FLOAT = 2u; static const uint SIGN_BIT = 0x80000000u; static const uint FULL_MASK = 0xFFFFFFFFu; +static const uint GROUP_SIZE = TF_GROUP_SIZE; [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); [[vk::binding(1,0)]] StructuredBuffer Input : register(t1, space0); [[vk::binding(2,0)]] RWStructuredBuffer Output : register(u2, space0); -[numthreads(128, 1, 1)] +[numthreads(GROUP_SIZE, 1, 1)] void csMapToUint(uint3 dispatchThreadID : SV_DispatchThreadID) { uint index = dispatchThreadID.x; diff --git a/examples/radix_sort/shaders/scatter.slang b/examples/radix_sort/shaders/scatter.slang index ddf2f3ac..f23ba79d 100644 --- a/examples/radix_sort/shaders/scatter.slang +++ b/examples/radix_sort/shaders/scatter.slang @@ -1,8 +1,5 @@ -static const uint GROUP_SIZE = 128; +static const uint GROUP_SIZE = TF_GROUP_SIZE; static const uint QUARTER_SIZE = GROUP_SIZE / 4; -#ifndef TF_HISTOGRAM_SIZE -#define TF_HISTOGRAM_SIZE 256 -#endif static const uint HISTOGRAM_SIZE = TF_HISTOGRAM_SIZE; [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py index df9535d6..aba33af3 100644 --- a/examples/radix_sort/sort.py +++ b/examples/radix_sort/sort.py @@ -18,6 +18,12 @@ } +def _dispatch_groups(work_items: int, threads_per_group: int) -> int: + if work_items <= 0: + return 0 + return (work_items + threads_per_group - 1) // threads_per_group + + def _prepare_keys(keys: np.ndarray) -> Tuple[np.ndarray, np.dtype, str]: array = 
np.asarray(keys) if array.ndim != 1: @@ -81,16 +87,27 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: self.group_size = group_size self.histogram_size = 1 << bits_per_pass + def inject_defines(filename: str, *, with_group: bool = False, with_histogram: bool = False) -> str: + defines = [] + if with_group: + defines.append(f"#define TF_GROUP_SIZE {self.group_size}u") + if with_histogram: + defines.append(f"#define TF_HISTOGRAM_SIZE {self.histogram_size}u") + source = _load_shader_source(filename) + if defines: + return "\n".join(defines) + "\n" + source + return source + self._map_to_uint_program = tf.createComputeProgramFromSlang( "radix_map_to_uint", - _load_shader_source("map_to_uint.slang"), + inject_defines("map_to_uint.slang", with_group=True), "csMapToUint", ro_count=2, rw_count=1, ) self._map_from_uint_program = tf.createComputeProgramFromSlang( "radix_map_from_uint", - _load_shader_source("map_from_uint.slang"), + inject_defines("map_from_uint.slang", with_group=True), "csMapFromUint", ro_count=2, rw_count=1, @@ -98,7 +115,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: self._histogram_program = tf.createComputeProgramFromSlang( "radix_histogram", - _load_shader_source("histogram.slang"), + inject_defines("histogram.slang", with_group=True), "csHistogram", ro_count=2, rw_count=1, @@ -138,7 +155,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: ro_count=2, rw_count=1, ) - scatter_source = f"#define TF_HISTOGRAM_SIZE {self.histogram_size}u\n" + _load_shader_source("scatter.slang") + scatter_source = inject_defines("scatter.slang", with_group=True, with_histogram=True) self._scatter_program = tf.createComputeProgramFromSlang( "radix_scatter", scatter_source, @@ -252,10 +269,20 @@ def sort( ): stack.callback(buf.release) + map_groups = _dispatch_groups(element_count, self.group_size) + reduction_group_size = 64 + unpack_groups = 
_dispatch_groups(histogram_size * num_groups, reduction_group_size) + prefix_local_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + prefix_block_groups = _dispatch_groups(histogram_size, reduction_group_size) + prefix_accum_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + bucket_scan_groups = _dispatch_groups(histogram_size, reduction_group_size) + scatter_groups = num_groups + histogram_groups = num_groups + self._map_to_uint_program.run( [map_buffer, key_buffers[0]], [key_buffers[1]], - element_count, + map_groups, ) key_in = key_buffers[1] @@ -266,47 +293,46 @@ def sort( params_array[2] = np.uint32(pass_index * self.bits_per_pass) params_buffer.setData(params_array) - dispatch_threads = num_groups * self.group_size self._histogram_program.run( [params_buffer, key_in], [packed_hist_buffer], - dispatch_threads, + histogram_groups, ) self._unpack_program.run( [params_buffer, packed_hist_buffer], [group_hist_buffer], - histogram_size * num_groups, + unpack_groups, ) self._prefix_local_program.run( [params_buffer, group_hist_buffer], [prefix_buffer, block_totals_buffer], - histogram_size * block_count, + prefix_local_groups, ) self._prefix_blocks_program.run( [params_buffer, block_totals_buffer], [block_prefix_buffer], - histogram_size, + prefix_block_groups, ) self._prefix_accum_program.run( [params_buffer, block_prefix_buffer], [prefix_buffer], - histogram_size * block_count, + prefix_accum_groups, ) self._bucket_scan_program.run( [params_buffer, prefix_buffer], [bucket_scan_buffer], - histogram_size, + bucket_scan_groups, ) self._scatter_program.run( [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], [key_out, val_out], - dispatch_threads, + scatter_groups, ) key_in, key_out = key_out, key_in @@ -316,7 +342,7 @@ def sort( self._map_from_uint_program.run( [map_buffer, key_in], [key_out], - element_count, + map_groups, ) sorted_keys = key_out.getData(key_dtype, element_count) diff --git 
a/tests/slang_compile_test.py b/tests/slang_compile_test.py index 05288e6f..0f99e2f2 100644 --- a/tests/slang_compile_test.py +++ b/tests/slang_compile_test.py @@ -58,7 +58,9 @@ def _should_skip_for_backend(exc: Exception) -> bool: class SlangCompilationTest(unittest.TestCase): def test_compile_and_execute_simple_shader(self) -> None: - invocation_count = 1 + thread_count = 1 + local_size = 64 + group_count = max((thread_count + local_size - 1) // local_size, 1) missing_artifacts = _missing_runtime_artifacts() if missing_artifacts: # pragma: no cover - environment not staged @@ -69,7 +71,7 @@ def test_compile_and_execute_simple_shader(self) -> None: ) try: - readonly_buffer = tf.createBuffer(invocation_count, 4, True) + readonly_buffer = tf.createBuffer(thread_count, 4, True) except RuntimeError as exc: # pragma: no cover - Vulkan not available self.skipTest(f"Vulkan buffer creation failed: {exc}") @@ -77,7 +79,7 @@ def test_compile_and_execute_simple_shader(self) -> None: resources.callback(readonly_buffer.release) try: - readwrite_buffer = tf.createBuffer(invocation_count, 4, False) + readwrite_buffer = tf.createBuffer(thread_count, 4, False) except RuntimeError as exc: # pragma: no cover - Vulkan not available self.skipTest(f"Vulkan buffer creation failed: {exc}") @@ -101,10 +103,10 @@ def test_compile_and_execute_simple_shader(self) -> None: readonly_buffer.setData(np.array([7], dtype=np.uint32)) readwrite_buffer.setData(np.zeros(1, dtype=np.uint32)) - program.run([readonly_buffer], [readwrite_buffer], invocation_count) + program.run([readonly_buffer], [readwrite_buffer], group_count) - result = readwrite_buffer.getData(np.dtype(np.uint32), invocation_count) - self.assertEqual(result.shape, (invocation_count,)) + result = readwrite_buffer.getData(np.dtype(np.uint32), thread_count) + self.assertEqual(result.shape, (thread_count,)) self.assertEqual(int(result[0]), 8) diff --git a/tests/vulkan_window_test.py b/tests/vulkan_window_test.py index 74fc6672..004d78ad 
100644 --- a/tests/vulkan_window_test.py +++ b/tests/vulkan_window_test.py @@ -21,10 +21,12 @@ class VulkanWindowTest(unittest.TestCase): def test_compute_dispatch_and_window_present(self): width = height = 8 - invocation_count = width * height + thread_count = width * height + local_size = 64 + group_count = max((thread_count + local_size - 1) // local_size, 1) try: - pixel_buffer = tf.createBuffer(invocation_count, 4, False) + pixel_buffer = tf.createBuffer(thread_count, 4, False) except RuntimeError as exc: # pragma: no cover - Vulkan not available self.skipTest(f"Vulkan buffer creation failed: {exc}") @@ -34,8 +36,8 @@ def test_compute_dispatch_and_window_present(self): program = tf.createComputeProgramFromGLSL(_SIMPLE_GLSL, ro_count=0, rw_count=1) resources.callback(program.release) - program.run([], [pixel_buffer], invocation_count) - pixels = pixel_buffer.getData(np.dtype(np.uint32), invocation_count) + program.run([], [pixel_buffer], group_count) + pixels = pixel_buffer.getData(np.dtype(np.uint32), thread_count) self.assertTrue(np.all(pixels == 0xFF3366FF), "Compute shader did not write expected color") try: From f304c4f2bd96ce03e11caffe56fed94c15697390 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sat, 8 Nov 2025 02:07:35 +0100 Subject: [PATCH 35/44] Update sorting example --- .gitignore | 1 + examples/radix_sort/__main__.py | 305 +++++++++++++----- examples/radix_sort/shaders/bucket_scan.slang | 9 - .../radix_sort/shaders/map_from_uint.slang | 4 - .../radix_sort/shaders/prefix_accum.slang | 9 - .../radix_sort/shaders/prefix_block.slang | 9 - .../radix_sort/shaders/prefix_local.slang | 9 - examples/radix_sort/shaders/scatter.slang | 48 ++- examples/radix_sort/shaders/unpack.slang | 4 - examples/radix_sort/sort.py | 48 +++ 10 files changed, 305 insertions(+), 141 deletions(-) diff --git a/.gitignore b/.gitignore index 52156960..b64dee42 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,4 @@ 
imgui.ini *.pyc /.cmake /CMakeFiles +/.debug diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py index ce50de74..ea3931b2 100644 --- a/examples/radix_sort/__main__.py +++ b/examples/radix_sort/__main__.py @@ -1,94 +1,255 @@ from __future__ import annotations import argparse +import math import time +from collections import deque from contextlib import ExitStack import numpy as np import TensorFrost as tf try: - from .sort import HistogramRadixSort + from .sort import HistogramRadixSort except ImportError: - import sys - from pathlib import Path + import sys + from pathlib import Path - _CURRENT_DIR = Path(__file__).resolve().parent - if str(_CURRENT_DIR) not in sys.path: - sys.path.insert(0, str(_CURRENT_DIR)) + _CURRENT_DIR = Path(__file__).resolve().parent + if str(_CURRENT_DIR) not in sys.path: + sys.path.insert(0, str(_CURRENT_DIR)) - from sort import HistogramRadixSort + from sort import HistogramRadixSort + + +_STAGE_NAMES = ( + "map_to_uint", + "histogram", + "unpack", + "prefix_local", + "prefix_blocks", + "prefix_accum", + "bucket_scan", + "scatter", + "map_from_uint", +) def _select_backend() -> None: - if hasattr(tf, "initialize"): - backend = getattr(tf, "vulkan", None) - if backend is None: - raise RuntimeError("TensorFrost Vulkan backend is unavailable on this build") - tf.initialize(backend) + if hasattr(tf, "initialize"): + backend = getattr(tf, "vulkan", None) + if backend is None: + raise RuntimeError("TensorFrost Vulkan backend is unavailable on this build") + tf.initialize(backend) + + +def _format_rate(value: float, unit: str) -> str: + if value <= 0.0 or not math.isfinite(value): + return f"0 {unit}" + + prefixes = ("", "K", "M", "G", "T", "P") + magnitude = 0 + while value >= 1000.0 and magnitude < len(prefixes) - 1: + value /= 1000.0 + magnitude += 1 + + return f"{value:6.2f} {prefixes[magnitude]}{unit}" + + +def _format_stage_name(name: str) -> str: + return name.replace("_", " ").title() def main() -> None: - 
parser = argparse.ArgumentParser(description="Histogram radix sort demo running on the Vulkan backend.") - parser.add_argument("--size", type=int, default=1 << 20, help="Number of key/value pairs to sort") - parser.add_argument("--bits", type=int, default=6, help="Bits processed per pass") - args = parser.parse_args() - - _select_backend() - - count = max(0, int(args.size)) - bits_per_pass = max(1, int(args.bits)) - rng = np.random.default_rng(1234) - - keys = rng.standard_normal(count, dtype=np.float32) - values = rng.integers(0, 1 << 31, size=count, dtype=np.uint32) - - with ExitStack() as stack: - renderdoc_is_available = getattr(tf, "renderdoc_is_available", None) - renderdoc_start = getattr(tf, "renderdoc_start_capture", None) - renderdoc_end = getattr(tf, "renderdoc_end_capture", None) - renderdoc_finalize = None - if ( - callable(renderdoc_is_available) - and renderdoc_is_available() - and callable(renderdoc_start) - and callable(renderdoc_end) - ): - renderdoc_start() - renderdoc_finalize = renderdoc_end - print("RenderDoc capture started") - - sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) - stack.callback(sorter.close) - try: - start_time = time.perf_counter() - sorted_keys, sorted_values = sorter.sort(keys, values) - elapsed = time.perf_counter() - start_time - finally: - if callable(renderdoc_finalize): - capture_path = renderdoc_finalize(launch_replay_ui=True) - if capture_path: - print(f"RenderDoc capture saved to: {capture_path}") - renderdoc_finalize = None - - if sorted_values is None: - sorted_values = np.empty_like(values) - - order = np.argsort(keys, kind="stable") - reference_keys = keys[order] - reference_values = values[order] - - key_match = np.allclose(sorted_keys, reference_keys, atol=0.0, rtol=0.0) - value_match = np.array_equal(sorted_values, reference_values) - - print(f"Sorted {count} elements with bits_per_pass={bits_per_pass}") - print(f"Sort elapsed: {elapsed * 1e3:.3f} ms ({elapsed:.6f} s)") - print(f"Keys match 
reference: {key_match}") - print(f"Values match reference: {value_match}") - if count: - preview = min(10, count) - print("First few sorted keys:") - print(sorted_keys[:preview]) + parser = argparse.ArgumentParser( + description="Interactive histogram radix sort demo running continuously on the Vulkan backend.", + ) + parser.add_argument("--size", type=int, default=1 << 20, help="Number of key/value pairs to sort") + parser.add_argument("--bits", type=int, default=6, help="Bits processed per pass") + parser.add_argument("--window-width", type=int, default=960, help="Window width in pixels") + parser.add_argument("--window-height", type=int, default=540, help="Window height in pixels") + parser.add_argument("--font-scale", type=float, default=1.6, help="Global ImGui font scale") + parser.add_argument("--history", type=int, default=240, help="Number of samples stored for metrics history") + parser.add_argument("--seed", type=int, default=1337, help="Seed for the random data generator") + args = parser.parse_args() + + _select_backend() + + count = max(0, int(args.size)) + bits_per_pass = max(1, int(args.bits)) + window_width = max(320, int(args.window_width)) + window_height = max(240, int(args.window_height)) + history_length = max(1, int(args.history)) + font_scale = max(0.1, float(args.font_scale)) if args.font_scale > 0 else 0.0 + + rng = np.random.default_rng(int(args.seed)) + keys = ( + rng.standard_normal(count, dtype=np.float32) + if count + else np.empty(0, dtype=np.float32) + ) + values = ( + rng.integers(0, 1 << 31, size=count, dtype=np.uint32) + if count + else np.empty(0, dtype=np.uint32) + ) + + reference_keys = None + reference_values = None + if count: + order = np.argsort(keys, kind="stable") + reference_keys = keys[order] + reference_values = values[order] + + frame_times = deque(maxlen=history_length) + sort_times = deque(maxlen=history_length) + stage_totals_overall = {name: 0.0 for name in _STAGE_NAMES} + + validated = False + validation_ok = 
False + validation_message = "" + + sort_count = 0 + total_kernel_time = 0.0 + + with ExitStack() as stack: + sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) + stack.callback(sorter.close) + + window_title = f"TensorFrost Radix Sort ({count:,} elements)" + window = tf.createWindow(window_width, window_height, window_title) + stack.callback(window.close) + + if font_scale > 0.0: + window.imgui_scale_all_sizes(font_scale) + window.imgui_set_font_global_scale(font_scale) + + start_time = time.perf_counter() + last_frame_time = start_time + + while window.isOpen(): + now = time.perf_counter() + dt = now - last_frame_time + last_frame_time = now + + if dt > 0.0: + frame_times.append(dt) + frame_time_total = sum(frame_times) + fps = (len(frame_times) / frame_time_total) if frame_times and frame_time_total > 0.0 else 0.0 + + sorted_keys, sorted_values = sorter.sort(keys, values, collect_stage_timings=True) + stage_timings = sorter.last_stage_timings or {} + kernel_time = float(sum(stage_timings.values())) + + if sort_times.maxlen == len(sort_times): + sort_times.popleft() + sort_times.append(kernel_time) + + sort_count += 1 + total_kernel_time += kernel_time + for name in _STAGE_NAMES: + stage_totals_overall[name] += stage_timings.get(name, 0.0) + + if not validated and reference_keys is not None and reference_values is not None: + key_match = np.allclose(sorted_keys, reference_keys, atol=0.0, rtol=0.0) + value_match = np.array_equal(sorted_values, reference_values) + validation_ok = bool(key_match and value_match) + validation_message = ( + "GPU results match CPU reference." if validation_ok else "Mismatch detected against CPU reference!" + ) + validated = True + + # Release arrays promptly once consumed. 
+ del sorted_keys + del sorted_values + + window_avg_sort = (sum(sort_times) / len(sort_times)) if sort_times else 0.0 + last_sort_ms = kernel_time * 1000.0 + avg_sort_ms = window_avg_sort * 1000.0 + + last_sort_rate = (1.0 / kernel_time) if kernel_time > 0.0 else 0.0 + window_sort_rate = (1.0 / window_avg_sort) if window_avg_sort > 0.0 else 0.0 + overall_avg_sort = (total_kernel_time / sort_count) if sort_count else 0.0 + overall_sort_rate = (1.0 / overall_avg_sort) if overall_avg_sort > 0.0 else 0.0 + + last_elements_per_sec = (count / kernel_time) if count and kernel_time > 0.0 else 0.0 + window_elements_per_sec = (count / window_avg_sort) if count and window_avg_sort > 0.0 else 0.0 + + history_data = np.array(sort_times, dtype=np.float32) * 1000.0 + total_elapsed = now - start_time + + visible, _ = window.imgui_begin("Radix Sort Performance", open=None, flags=0) + if visible: + window.imgui_text(f"Elements: {count:,}") + window.imgui_text(f"Bits per pass: {bits_per_pass}") + window.imgui_separator() + window.imgui_text(f"Frame time: {dt * 1000.0:7.3f} ms") + window.imgui_text(f"Average FPS: {fps:6.1f}") + window.imgui_separator() + + if sort_count: + window.imgui_text(f"Last kernel time: {last_sort_ms:7.3f} ms") + if window_avg_sort > 0.0: + window.imgui_text(f"Window avg kernel: {avg_sort_ms:7.3f} ms") + if overall_avg_sort > 0.0: + window.imgui_text(f"Overall avg kernel: {overall_avg_sort * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text(f"Sorts/sec (last): {last_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (window): {window_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (overall): {overall_sort_rate:8.2f}") + + if count: + window.imgui_spacing() + window.imgui_text( + f"Elements/sec (last): {_format_rate(last_elements_per_sec, ' elements/s')}" + ) + window.imgui_text( + f"Elements/sec (window): {_format_rate(window_elements_per_sec, ' elements/s')}" + ) + + window.imgui_spacing() + window.imgui_text(f"Total sorts: {sort_count:,}") + 
else: + if count: + window.imgui_text("Waiting for first GPU sort...") + else: + window.imgui_text("No elements to sort.") + + if validated: + color = (0.20, 0.80, 0.40, 1.0) if validation_ok else (0.92, 0.35, 0.32, 1.0) + window.imgui_text_colored(color, validation_message) + elif count: + window.imgui_text("Validating against CPU reference...") + + if history_data.size: + max_plot = float(np.max(history_data)) if history_data.size else 1.0 + max_plot = max(1.0, max_plot * 1.1) + window.imgui_plot_lines( + "Kernel time history (ms)", + history_data, + values_offset=0, + overlay_text=f"{history_data.size} sample history", + scale_min=0.0, + scale_max=max_plot, + graph_size=(0.0, 140.0), + stride=4, + ) + + if stage_timings: + window.imgui_separator() + window.imgui_text("Kernel stages (last run):") + for name, duration in sorted(stage_timings.items(), key=lambda item: item[1], reverse=True): + window.imgui_text(f"{_format_stage_name(name)}: {duration * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text("Kernel stages (overall avg):") + for name in sorted(stage_totals_overall.keys()): + avg_duration = (stage_totals_overall[name] / sort_count) if sort_count else 0.0 + window.imgui_text(f"{_format_stage_name(name)}: {avg_duration * 1000.0:7.3f} ms") + + window.imgui_end() + window.present() if __name__ == "__main__": diff --git a/examples/radix_sort/shaders/bucket_scan.slang b/examples/radix_sort/shaders/bucket_scan.slang index e748d887..43d18de9 100644 --- a/examples/radix_sort/shaders/bucket_scan.slang +++ b/examples/radix_sort/shaders/bucket_scan.slang @@ -1,12 +1,3 @@ -static const uint GROUP_SIZE = 128; -static const uint QUARTER_SIZE = GROUP_SIZE / 4; -static const uint MAX_HIST_SIZE = 256; - -uint packedCount(uint histogramSize) -{ - return (histogramSize + 3u) >> 2; -} - [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); [[vk::binding(1,0)]] StructuredBuffer GroupPrefix : register(t1, space0); [[vk::binding(2,0)]] 
RWStructuredBuffer BucketScan : register(u2, space0); diff --git a/examples/radix_sort/shaders/map_from_uint.slang b/examples/radix_sort/shaders/map_from_uint.slang index f9542ffe..643b4509 100644 --- a/examples/radix_sort/shaders/map_from_uint.slang +++ b/examples/radix_sort/shaders/map_from_uint.slang @@ -3,10 +3,6 @@ static const uint TYPE_INT = 1u; static const uint TYPE_FLOAT = 2u; static const uint SIGN_BIT = 0x80000000u; static const uint FULL_MASK = 0xFFFFFFFFu; - -#if !defined(TF_GROUP_SIZE) -#error "TF_GROUP_SIZE must be defined" -#endif static const uint GROUP_SIZE = TF_GROUP_SIZE; [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); diff --git a/examples/radix_sort/shaders/prefix_accum.slang b/examples/radix_sort/shaders/prefix_accum.slang index 81966007..26e7df66 100644 --- a/examples/radix_sort/shaders/prefix_accum.slang +++ b/examples/radix_sort/shaders/prefix_accum.slang @@ -1,12 +1,3 @@ -static const uint GROUP_SIZE = 128; -static const uint QUARTER_SIZE = GROUP_SIZE / 4; -static const uint MAX_HIST_SIZE = 256; - -uint packedCount(uint histogramSize) -{ - return (histogramSize + 3u) >> 2; -} - [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); [[vk::binding(1,0)]] StructuredBuffer BlockPrefix : register(t1, space0); [[vk::binding(2,0)]] RWStructuredBuffer GroupPrefix : register(u2, space0); diff --git a/examples/radix_sort/shaders/prefix_block.slang b/examples/radix_sort/shaders/prefix_block.slang index 3a4909b0..5c526574 100644 --- a/examples/radix_sort/shaders/prefix_block.slang +++ b/examples/radix_sort/shaders/prefix_block.slang @@ -1,12 +1,3 @@ -static const uint GROUP_SIZE = 128; -static const uint QUARTER_SIZE = GROUP_SIZE / 4; -static const uint MAX_HIST_SIZE = 256; - -uint packedCount(uint histogramSize) -{ - return (histogramSize + 3u) >> 2; -} - [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); [[vk::binding(1,0)]] StructuredBuffer BlockTotals : register(t1, space0); 
[[vk::binding(2,0)]] RWStructuredBuffer BlockPrefix : register(u2, space0); diff --git a/examples/radix_sort/shaders/prefix_local.slang b/examples/radix_sort/shaders/prefix_local.slang index 5e70d6dc..14bed0de 100644 --- a/examples/radix_sort/shaders/prefix_local.slang +++ b/examples/radix_sort/shaders/prefix_local.slang @@ -1,12 +1,3 @@ -static const uint GROUP_SIZE = 128; -static const uint QUARTER_SIZE = GROUP_SIZE / 4; -static const uint MAX_HIST_SIZE = 256; - -uint packedCount(uint histogramSize) -{ - return (histogramSize + 3u) >> 2; -} - [[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); [[vk::binding(1,0)]] StructuredBuffer GroupHistogram : register(t1, space0); [[vk::binding(2,0)]] RWStructuredBuffer GroupPrefix : register(u2, space0); diff --git a/examples/radix_sort/shaders/scatter.slang b/examples/radix_sort/shaders/scatter.slang index f23ba79d..f3b7a11a 100644 --- a/examples/radix_sort/shaders/scatter.slang +++ b/examples/radix_sort/shaders/scatter.slang @@ -28,9 +28,9 @@ void csScatter(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) uint lane = localID.x; - for (uint idx = lane; idx < HISTOGRAM_SIZE; idx += GROUP_SIZE) + [unroll] for (uint idx = lane; idx < HISTOGRAM_SIZE; idx += GROUP_SIZE) { - halfCount[idx] = 0; + if (idx < HISTOGRAM_SIZE) halfCount[idx] = 0; } uint globalIndex = group * GROUP_SIZE + lane; @@ -64,31 +64,29 @@ void csScatter(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) } GroupMemoryBarrierWithGroupSync(); - if (!active) - return; - - uint prevBucket = (bit == 0) ? 0 : BucketScan[bit - 1u]; - uint prevGroup = (group == 0) ? 0 : GroupPrefix[(group - 1u) * HISTOGRAM_SIZE + bit]; + if (active) { + uint prevBucket = (bit == 0) ? 0 : BucketScan[bit - 1u]; + uint prevGroup = (group == 0) ? 
0 : GroupPrefix[(group - 1u) * HISTOGRAM_SIZE + bit]; - uint quarterOffset = 0; - if (quarterIndex > 0) - { - uint packed = halfCount[bit]; - uint shiftBytes = (quarterIndex - 1u) * 8u; - quarterOffset = (packed >> shiftBytes) & 0xFFu; - } + uint quarterOffset = 0; + if (quarterIndex > 0) + { + uint packed = halfCount[bit]; + uint shiftBytes = (quarterIndex - 1u) * 8u; + quarterOffset = (packed >> shiftBytes) & 0xFFu; + } - uint begin = quarterIndex * QUARTER_SIZE; - uint localCount = 0; - for (uint t = begin; t < lane; ++t) - { - if (tempBits[t] == bit) - ++localCount; - } + uint begin = quarterIndex * QUARTER_SIZE; + uint localCount = 0; + [unroll] for (uint t = begin; t < lane; ++t) + { + localCount += (tempBits[t] == bit) ? 1u : 0u; + } - uint totalOffset = prevBucket + prevGroup + quarterOffset + localCount; - KeysOut[totalOffset] = key; + uint totalOffset = prevBucket + prevGroup + quarterOffset + localCount; + KeysOut[totalOffset] = key; - if (hasValues != 0) - ValuesOut[totalOffset] = value; + if (hasValues != 0) + ValuesOut[totalOffset] = value; + } } diff --git a/examples/radix_sort/shaders/unpack.slang b/examples/radix_sort/shaders/unpack.slang index ee53179b..03eb29af 100644 --- a/examples/radix_sort/shaders/unpack.slang +++ b/examples/radix_sort/shaders/unpack.slang @@ -1,7 +1,3 @@ -static const uint GROUP_SIZE = 128; -static const uint QUARTER_SIZE = GROUP_SIZE / 4; -static const uint MAX_HIST_SIZE = 256; - uint packedCount(uint histogramSize) { return (histogramSize + 3u) >> 2; diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py index aba33af3..b79cbd7e 100644 --- a/examples/radix_sort/sort.py +++ b/examples/radix_sort/sort.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Dict, Optional, Tuple +import time import numpy as np @@ -86,6 +87,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: self.block_size = block_size self.group_size = group_size 
self.histogram_size = 1 << bits_per_pass + self.last_stage_timings = None def inject_defines(filename: str, *, with_group: bool = False, with_histogram: bool = False) -> str: defines = [] @@ -189,6 +191,7 @@ def sort( values: Optional[np.ndarray] = None, *, max_bits: int = 32, + collect_stage_timings: bool = False, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: keys_array, key_dtype, key_kind = _prepare_keys(keys) element_count = int(keys_array.shape[0]) @@ -203,6 +206,7 @@ def sort( if element_count == 0: empty_keys = keys_array.copy() + self.last_stage_timings = {} if collect_stage_timings else None if values_array is None: return empty_keys, None return empty_keys, values_array.copy() @@ -229,6 +233,21 @@ def sort( map_params[0] = np.uint32(element_count) map_params[1] = _TYPE_CODES[key_kind] + if collect_stage_timings: + stage_totals = { + "map_to_uint": 0.0, + "histogram": 0.0, + "unpack": 0.0, + "prefix_local": 0.0, + "prefix_blocks": 0.0, + "prefix_accum": 0.0, + "bucket_scan": 0.0, + "scatter": 0.0, + "map_from_uint": 0.0, + } + else: + stage_totals = {} + with ExitStack() as stack: params_buffer = tf.createBuffer(params_array.size, 4, True) stack.callback(params_buffer.release) @@ -279,11 +298,14 @@ def sort( scatter_groups = num_groups histogram_groups = num_groups + start = time.perf_counter() if collect_stage_timings else None self._map_to_uint_program.run( [map_buffer, key_buffers[0]], [key_buffers[1]], map_groups, ) + if collect_stage_timings and start is not None: + stage_totals["map_to_uint"] += time.perf_counter() - start key_in = key_buffers[1] key_out = key_buffers[0] @@ -293,57 +315,81 @@ def sort( params_array[2] = np.uint32(pass_index * self.bits_per_pass) params_buffer.setData(params_array) + start = time.perf_counter() if collect_stage_timings else None self._histogram_program.run( [params_buffer, key_in], [packed_hist_buffer], histogram_groups, ) + if collect_stage_timings and start is not None: + stage_totals["histogram"] += 
time.perf_counter() - start + start = time.perf_counter() if collect_stage_timings else None self._unpack_program.run( [params_buffer, packed_hist_buffer], [group_hist_buffer], unpack_groups, ) + if collect_stage_timings and start is not None: + stage_totals["unpack"] += time.perf_counter() - start + start = time.perf_counter() if collect_stage_timings else None self._prefix_local_program.run( [params_buffer, group_hist_buffer], [prefix_buffer, block_totals_buffer], prefix_local_groups, ) + if collect_stage_timings and start is not None: + stage_totals["prefix_local"] += time.perf_counter() - start + start = time.perf_counter() if collect_stage_timings else None self._prefix_blocks_program.run( [params_buffer, block_totals_buffer], [block_prefix_buffer], prefix_block_groups, ) + if collect_stage_timings and start is not None: + stage_totals["prefix_blocks"] += time.perf_counter() - start + start = time.perf_counter() if collect_stage_timings else None self._prefix_accum_program.run( [params_buffer, block_prefix_buffer], [prefix_buffer], prefix_accum_groups, ) + if collect_stage_timings and start is not None: + stage_totals["prefix_accum"] += time.perf_counter() - start + start = time.perf_counter() if collect_stage_timings else None self._bucket_scan_program.run( [params_buffer, prefix_buffer], [bucket_scan_buffer], bucket_scan_groups, ) + if collect_stage_timings and start is not None: + stage_totals["bucket_scan"] += time.perf_counter() - start + start = time.perf_counter() if collect_stage_timings else None self._scatter_program.run( [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], [key_out, val_out], scatter_groups, ) + if collect_stage_timings and start is not None: + stage_totals["scatter"] += time.perf_counter() - start key_in, key_out = key_out, key_in if values_array is not None: val_in, val_out = val_out, val_in + start = time.perf_counter() if collect_stage_timings else None self._map_from_uint_program.run( [map_buffer, key_in], 
[key_out], map_groups, ) + if collect_stage_timings and start is not None: + stage_totals["map_from_uint"] += time.perf_counter() - start sorted_keys = key_out.getData(key_dtype, element_count) if values_array is not None and values_dtype is not None: @@ -351,7 +397,9 @@ def sort( else: sorted_values = None + self.last_stage_timings = stage_totals if collect_stage_timings else None return sorted_keys, sorted_values + return sorted_keys, sorted_values, None _SORTER_CACHE: Dict[_SorterKey, HistogramRadixSort] = {} From 426da84779fbaf930064921af86c4e34e9b085bc Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sat, 8 Nov 2025 04:50:40 +0100 Subject: [PATCH 36/44] Use deviceLocal buffers only --- TensorFrost/Backend/src/Vulkan.cpp | 143 +++++++++++++++++++++++++---- 1 file changed, 127 insertions(+), 16 deletions(-) diff --git a/TensorFrost/Backend/src/Vulkan.cpp b/TensorFrost/Backend/src/Vulkan.cpp index 755220af..c9cce616 100644 --- a/TensorFrost/Backend/src/Vulkan.cpp +++ b/TensorFrost/Backend/src/Vulkan.cpp @@ -261,12 +261,20 @@ Buffer createBuffer(size_t count, size_t dtypeSize, bool readOnly) { auto memReq = ctx.device.getBufferMemoryRequirements(buf.buffer); auto memProps = ctx.physicalDevice.getMemoryProperties(); - uint32_t memTypeIndex = 0; + uint32_t memTypeIndex = UINT32_MAX; + // Prefer device-local memory for GPU performance for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { if ((memReq.memoryTypeBits & (1u< buf.size) throw std::out_of_range("write out of range"); + if (bytes == 0) return; + + // Create a temporary host-visible staging buffer for upload + vk::BufferCreateInfo bci({}, bytes, vk::BufferUsageFlagBits::eTransferSrc); + vk::Buffer staging = ctx.device.createBuffer(bci); + auto memReq = ctx.device.getBufferMemoryRequirements(staging); + + auto memProps = ctx.physicalDevice.getMemoryProperties(); + uint32_t memTypeIndex = UINT32_MAX; + // Prefer HOST_VISIBLE | HOST_COHERENT for simple map 
without flush; fallback to HOST_VISIBLE + for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { + auto f = memProps.memoryTypes[i].propertyFlags; + if ((memReq.memoryTypeBits & (1u<(p) + (offset - mapOff), src, bytes); - vk::MappedMemoryRange rng(buf.memory, mapOff, mapSz); - ctx.device.flushMappedMemoryRanges(rng); // needed if memory not coherent - ctx.device.unmapMemory(buf.memory); + // Destroy staging resources + ctx.device.destroyBuffer(staging); + ctx.device.freeMemory(stagingMem); } void getBufferData(const Buffer& buf, void* dst, size_t bytes, size_t offset) { auto& ctx = getVulkanContext(); if (offset + bytes > buf.size) throw std::out_of_range("read out of range"); + if (bytes == 0) return; + + // Create a temporary host-visible staging buffer for download + vk::BufferCreateInfo bci({}, bytes, vk::BufferUsageFlagBits::eTransferDst); + vk::Buffer staging = ctx.device.createBuffer(bci); + auto memReq = ctx.device.getBufferMemoryRequirements(staging); + + auto memProps = ctx.physicalDevice.getMemoryProperties(); + uint32_t memTypeIndex = UINT32_MAX; + // Prefer HOST_VISIBLE | HOST_COHERENT; fallback to HOST_VISIBLE + for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { + auto f = memProps.memoryTypes[i].propertyFlags; + if ((memReq.memoryTypeBits & (1u<(p) + (offset - mapOff), bytes); - ctx.device.unmapMemory(buf.memory); + // Destroy staging resources + ctx.device.destroyBuffer(staging); + ctx.device.freeMemory(stagingMem); } VulkanContext::~VulkanContext() { From e95c500f1776f30739234f5b6ab29a3da74a0ad5 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sat, 8 Nov 2025 18:28:29 +0100 Subject: [PATCH 37/44] Update sort example --- examples/radix_sort/__main__.py | 36 ++++++----- .../radix_sort/shaders/validate_sorted.slang | 44 ++++++++++++++ examples/radix_sort/sort.py | 60 +++++++++++++++++-- 3 files changed, 121 insertions(+), 19 deletions(-) create mode 100644 
examples/radix_sort/shaders/validate_sorted.slang diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py index ea3931b2..e141c6b4 100644 --- a/examples/radix_sort/__main__.py +++ b/examples/radix_sort/__main__.py @@ -94,12 +94,10 @@ def main() -> None: else np.empty(0, dtype=np.uint32) ) + # GPU-side validation: we'll run a one-time kernel check that counts out-of-order adjacent pairs. + # This avoids reading back entire arrays for CPU comparison. reference_keys = None reference_values = None - if count: - order = np.argsort(keys, kind="stable") - reference_keys = keys[order] - reference_values = values[order] frame_times = deque(maxlen=history_length) sort_times = deque(maxlen=history_length) @@ -137,9 +135,20 @@ def main() -> None: frame_time_total = sum(frame_times) fps = (len(frame_times) / frame_time_total) if frame_times and frame_time_total > 0.0 else 0.0 - sorted_keys, sorted_values = sorter.sort(keys, values, collect_stage_timings=True) + # Perform sort; on the first run also validate on GPU and avoid full array readback. 
+ do_validate = not validated + _keys_out, _vals_out = sorter.sort( + keys, + values, + collect_stage_timings=True, + validate=do_validate, + return_arrays=False, + ) stage_timings = sorter.last_stage_timings or {} - kernel_time = float(sum(stage_timings.values())) + # Separate total_pass (overall) from per-stage summed time + total_pass_time = float(stage_timings.get("total_pass", 0.0)) + stage_sum_time = float(sum(v for k, v in stage_timings.items() if k != "total_pass")) + kernel_time = total_pass_time if total_pass_time > 0.0 else stage_sum_time if sort_times.maxlen == len(sort_times): sort_times.popleft() @@ -150,18 +159,15 @@ def main() -> None: for name in _STAGE_NAMES: stage_totals_overall[name] += stage_timings.get(name, 0.0) - if not validated and reference_keys is not None and reference_values is not None: - key_match = np.allclose(sorted_keys, reference_keys, atol=0.0, rtol=0.0) - value_match = np.array_equal(sorted_values, reference_values) - validation_ok = bool(key_match and value_match) + if do_validate: + errors = int(getattr(sorter, "last_validation_errors", 0) or 0) + validation_ok = (errors == 0) validation_message = ( - "GPU results match CPU reference." if validation_ok else "Mismatch detected against CPU reference!" + "GPU validation passed (sorted)." if validation_ok else f"GPU validation failed: {errors} out-of-order pairs" ) validated = True - # Release arrays promptly once consumed. - del sorted_keys - del sorted_values + # No large array readback performed when return_arrays=False. 
window_avg_sort = (sum(sort_times) / len(sort_times)) if sort_times else 0.0 last_sort_ms = kernel_time * 1000.0 @@ -189,6 +195,8 @@ def main() -> None: if sort_count: window.imgui_text(f"Last kernel time: {last_sort_ms:7.3f} ms") + if total_pass_time > 0.0: + window.imgui_text(f"Total pass time (last): {total_pass_time * 1000.0:7.3f} ms") if window_avg_sort > 0.0: window.imgui_text(f"Window avg kernel: {avg_sort_ms:7.3f} ms") if overall_avg_sort > 0.0: diff --git a/examples/radix_sort/shaders/validate_sorted.slang b/examples/radix_sort/shaders/validate_sorted.slang new file mode 100644 index 00000000..c709bbbb --- /dev/null +++ b/examples/radix_sort/shaders/validate_sorted.slang @@ -0,0 +1,44 @@ +static const uint TYPE_UINT = 0u; +static const uint TYPE_INT = 1u; +static const uint TYPE_FLOAT = 2u; +static const uint SIGN_BIT = 0x80000000u; +static const uint FULL_MASK = 0xFFFFFFFFu; +static const uint GROUP_SIZE = TF_GROUP_SIZE; + +[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); // [0] = count, [1] = typeCode +[[vk::binding(1,0)]] StructuredBuffer Keys : register(t1, space0); // original key bit patterns +[[vk::binding(2,0)]] RWStructuredBuffer Errors : register(u2, space0); // single uint atomic counter + +// Map a key to sortable uint (same transform as map_to_uint.slang) +uint mapKey(uint raw, uint typeCode) +{ + if (typeCode == TYPE_INT) + { + raw ^= SIGN_BIT; + } + else if (typeCode == TYPE_FLOAT) + { + uint mask = ((raw >> 31) == 1u) ? 
FULL_MASK : SIGN_BIT; + raw ^= mask; + } + return raw; +} + +[numthreads(GROUP_SIZE, 1, 1)] +void csValidate(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint i = dispatchThreadID.x; + uint n = Params[0]; + if (n < 2 || i >= n - 1) + return; + + uint typeCode = Params[1]; + + uint a = mapKey(Keys[i], typeCode); + uint b = mapKey(Keys[i + 1u], typeCode); + + if (a > b) + { + InterlockedAdd(Errors[0], 1u); + } +} \ No newline at end of file diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py index b79cbd7e..22dd11a5 100644 --- a/examples/radix_sort/sort.py +++ b/examples/radix_sort/sort.py @@ -88,6 +88,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: self.group_size = group_size self.histogram_size = 1 << bits_per_pass self.last_stage_timings = None + self.last_validation_errors: Optional[int] = None def inject_defines(filename: str, *, with_group: bool = False, with_histogram: bool = False) -> str: defines = [] @@ -166,6 +167,15 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b rw_count=2, ) + # Validation program: checks if adjacent key pairs are sorted (after mapping back to original type) + self._validate_program = tf.createComputeProgramFromSlang( + "radix_validate_sorted", + inject_defines("validate_sorted.slang", with_group=True), + "csValidate", + ro_count=2, + rw_count=1, + ) + self._dummy_values_buffer = tf.createBuffer(1, 4, False) def close(self) -> None: @@ -179,6 +189,7 @@ def close(self) -> None: self._prefix_accum_program, self._bucket_scan_program, self._scatter_program, + self._validate_program, ): if program is not None: program.release() @@ -192,6 +203,8 @@ def sort( *, max_bits: int = 32, collect_stage_timings: bool = False, + validate: bool = False, + return_arrays: bool = True, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: keys_array, key_dtype, key_kind = _prepare_keys(keys) element_count = int(keys_array.shape[0]) @@ -207,6 +220,8 @@ def sort( if 
element_count == 0: empty_keys = keys_array.copy() self.last_stage_timings = {} if collect_stage_timings else None + if validate: + self.last_validation_errors = 0 if values_array is None: return empty_keys, None return empty_keys, values_array.copy() @@ -298,6 +313,8 @@ def sort( scatter_groups = num_groups histogram_groups = num_groups + # Total pass timer starts at the first dispatch and ends after the last (map_from_uint) + total_start = time.perf_counter() if collect_stage_timings else None start = time.perf_counter() if collect_stage_timings else None self._map_to_uint_program.run( [map_buffer, key_buffers[0]], @@ -391,15 +408,48 @@ def sort( if collect_stage_timings and start is not None: stage_totals["map_from_uint"] += time.perf_counter() - start - sorted_keys = key_out.getData(key_dtype, element_count) - if values_array is not None and values_dtype is not None: - sorted_values = val_in.getData(values_dtype, element_count) + # Record total time from first to last dispatch in the sort pass + if collect_stage_timings and total_start is not None: + stage_totals["total_pass"] = time.perf_counter() - total_start + + # Optional GPU-side validation: check adjacent pairs and atomically count violations + if validate: + validate_params = np.zeros(2, dtype=np.uint32) + validate_params[0] = np.uint32(element_count) + validate_params[1] = map_params[1] # type code + + validate_params_buf = tf.createBuffer(validate_params.size, 4, True) + stack.callback(validate_params_buf.release) + validate_params_buf.setData(validate_params) + + error_buf = tf.createBuffer(1, 4, False) + stack.callback(error_buf.release) + error_zero = np.zeros(1, dtype=np.uint32) + error_buf.setData(error_zero) + + # Reuse map_groups; kernel early-outs for i >= n-1 + self._validate_program.run( + [validate_params_buf, key_out], + [error_buf], + map_groups, + ) + + error_count = int(error_buf.getData(np.dtype(np.uint32), 1)[0]) + self.last_validation_errors = error_count + + if return_arrays: + 
sorted_keys = key_out.getData(key_dtype, element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_in.getData(values_dtype, element_count) + else: + sorted_values = None else: - sorted_values = None + # Avoid full readback when not needed by caller + sorted_keys = np.empty(0, dtype=key_dtype) + sorted_values = (np.empty(0, dtype=values_dtype) if values_array is not None and values_dtype is not None else None) self.last_stage_timings = stage_totals if collect_stage_timings else None return sorted_keys, sorted_values - return sorted_keys, sorted_values, None _SORTER_CACHE: Dict[_SorterKey, HistogramRadixSort] = {} From 7a213495d7f6a55ef917669bc00ec60075690143 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sat, 8 Nov 2025 22:13:40 +0100 Subject: [PATCH 38/44] Update PybindModule.cpp --- TensorFrost/PybindModule.cpp | 84 ++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp index 80e7c129..8b88e2f5 100644 --- a/TensorFrost/PybindModule.cpp +++ b/TensorFrost/PybindModule.cpp @@ -132,48 +132,48 @@ PYBIND11_MODULE(TensorFrost, m) { py::print("TensorFrost module loaded in debug mode! 
Expect slow performance."); #endif - // TEST CODE - TFProgram program([]() -> auto { - Values inputs; - Values outputs; - Value a = 5; - Value b = 10; - Value f = 2.5f; - Value g = 3.5f; - Value c = a + b * 3; - Value mem = memory({a, b, c}, TFFloat32); - inputs.push_back(mem); - vmap({a, b, c}, [&](Values ids0) { - Value something = tofloat(mem * sin(f + g)); - outputs.push_back(something); - }); - vmap({a, b, c}, [&](Values ids0) { - Value imem = toint(mem * sin(f + g)); - Value d = c + b + ids0[0] * imem; - Value m0; - vmap({c}, [&](Values ids1) { - m0 = 0; - }); - if_cond(d > 0, [&]() { - Value t = d * c * imem; - vmap({c}, [&](Values ids1) { - m0.Set(t * imem[{ids0[1], ids0[1], ids0[1]}]); - }); - }, [&]() { - Value t = d * c / imem; - vmap({c}, [&](Values ids1) { - m0.Set(t / imem[{ids1[0], ids0[0], ids0[1]}]); - }); - }); - vmap({c, c}, [&](Values ids1) { - Value m = m0 * imem[{ids1[1], ids1[0], ids0[0]}]; - outputs.push_back(m); - }); - }); - return std::make_pair(inputs, outputs); - }); - program.Compile(); - py::print(program.DebugPrint()); + // // TEST CODE + // TFProgram program([]() -> auto { + // Values inputs; + // Values outputs; + // Value a = 5; + // Value b = 10; + // Value f = 2.5f; + // Value g = 3.5f; + // Value c = a + b * 3; + // Value mem = memory({a, b, c}, TFFloat32); + // inputs.push_back(mem); + // vmap({a, b, c}, [&](Values ids0) { + // Value something = tofloat(mem * sin(f + g)); + // outputs.push_back(something); + // }); + // vmap({a, b, c}, [&](Values ids0) { + // Value imem = toint(mem * sin(f + g)); + // Value d = c + b + ids0[0] * imem; + // Value m0; + // vmap({c}, [&](Values ids1) { + // m0 = 0; + // }); + // if_cond(d > 0, [&]() { + // Value t = d * c * imem; + // vmap({c}, [&](Values ids1) { + // m0.Set(t * imem[{ids0[1], ids0[1], ids0[1]}]); + // }); + // }, [&]() { + // Value t = d * c / imem; + // vmap({c}, [&](Values ids1) { + // m0.Set(t / imem[{ids1[0], ids0[0], ids0[1]}]); + // }); + // }); + // vmap({c, c}, 
[&](Values ids1) { + // Value m = m0 * imem[{ids1[1], ids1[0], ids0[0]}]; + // outputs.push_back(m); + // }); + // }); + // return std::make_pair(inputs, outputs); + // }); + // program.Compile(); + // py::print(program.DebugPrint()); VulkanDefinitions(m); From 822dc599d7965f05dfca81c17348cbfcd53e9529 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sat, 8 Nov 2025 22:17:44 +0100 Subject: [PATCH 39/44] Move legacy tests --- tests/{ => legacy}/autograd_test.py | 0 tests/{ => legacy}/linalg_test.py | 0 tests/{ => legacy}/reshape_reduction_test.py | 0 tests/{ => legacy}/sorting_opengl_test.py | 0 tests/{ => legacy}/sorting_test.py | 0 tests/{ => legacy}/split_dim_test.py | 0 6 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ => legacy}/autograd_test.py (100%) rename tests/{ => legacy}/linalg_test.py (100%) rename tests/{ => legacy}/reshape_reduction_test.py (100%) rename tests/{ => legacy}/sorting_opengl_test.py (100%) rename tests/{ => legacy}/sorting_test.py (100%) rename tests/{ => legacy}/split_dim_test.py (100%) diff --git a/tests/autograd_test.py b/tests/legacy/autograd_test.py similarity index 100% rename from tests/autograd_test.py rename to tests/legacy/autograd_test.py diff --git a/tests/linalg_test.py b/tests/legacy/linalg_test.py similarity index 100% rename from tests/linalg_test.py rename to tests/legacy/linalg_test.py diff --git a/tests/reshape_reduction_test.py b/tests/legacy/reshape_reduction_test.py similarity index 100% rename from tests/reshape_reduction_test.py rename to tests/legacy/reshape_reduction_test.py diff --git a/tests/sorting_opengl_test.py b/tests/legacy/sorting_opengl_test.py similarity index 100% rename from tests/sorting_opengl_test.py rename to tests/legacy/sorting_opengl_test.py diff --git a/tests/sorting_test.py b/tests/legacy/sorting_test.py similarity index 100% rename from tests/sorting_test.py rename to tests/legacy/sorting_test.py diff --git 
a/tests/split_dim_test.py b/tests/legacy/split_dim_test.py similarity index 100% rename from tests/split_dim_test.py rename to tests/legacy/split_dim_test.py From bd3e2dea7ac9956fbc51bb5fa04e51ff24faef6c Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sun, 9 Nov 2025 03:17:15 +0100 Subject: [PATCH 40/44] Remove manual releases. --- Python/TensorFrost/sort.py | 236 +++++++++++------------ examples/Slang/mandelbrot.py | 5 +- examples/radix_sort/__main__.py | 16 +- examples/radix_sort/sort.py | 322 +++++++++++++++----------------- tests/slang_compile_test.py | 54 +++--- tests/vulkan_window_test.py | 24 ++- 6 files changed, 315 insertions(+), 342 deletions(-) diff --git a/Python/TensorFrost/sort.py b/Python/TensorFrost/sort.py index 60845e5c..5148eca9 100644 --- a/Python/TensorFrost/sort.py +++ b/Python/TensorFrost/sort.py @@ -1,6 +1,5 @@ from __future__ import annotations -from contextlib import ExitStack from dataclasses import dataclass from importlib import resources from typing import Dict, Optional, Tuple @@ -156,21 +155,19 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: self._dummy_values_buffer = tf.createBuffer(1, 4, False) def close(self) -> None: - for program in ( - self._map_to_uint_program, - self._map_from_uint_program, - self._histogram_program, - self._unpack_program, - self._prefix_local_program, - self._prefix_blocks_program, - self._prefix_accum_program, - self._bucket_scan_program, - self._scatter_program, + for attr in ( + "_map_to_uint_program", + "_map_from_uint_program", + "_histogram_program", + "_unpack_program", + "_prefix_local_program", + "_prefix_blocks_program", + "_prefix_accum_program", + "_bucket_scan_program", + "_scatter_program", ): - if program is not None: - program.release() - if self._dummy_values_buffer is not None: - self._dummy_values_buffer.release() + setattr(self, attr, None) + self._dummy_values_buffer = None def sort( self, @@ 
-218,127 +215,110 @@ def sort( map_params[0] = np.uint32(element_count) map_params[1] = _TYPE_CODES[key_kind] - with ExitStack() as stack: - params_buffer = tf.createBuffer(params_array.size, 4, True) - stack.callback(params_buffer.release) + params_buffer = tf.createBuffer(params_array.size, 4, True) + params_buffer.setData(params_array) + + map_buffer = tf.createBuffer(map_params.size, 4, True) + map_buffer.setData(map_params) + + key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + key_buffers[0].setData(keys_array) + + if values_array is not None: + value_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + value_buffers[0].setData(values_array) + else: + dummy = self._dummy_values_buffer + value_buffers = [dummy, dummy] + + packed_hist_buffer = tf.createBuffer(max(packed_count * num_groups, 1), 4, False) + group_hist_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + prefix_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + block_totals_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + block_prefix_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + bucket_scan_buffer = tf.createBuffer(max(histogram_size, 1), 4, False) + + map_groups = _dispatch_groups(element_count, self.group_size) + reduction_group_size = 64 + unpack_groups = _dispatch_groups(histogram_size * num_groups, reduction_group_size) + prefix_local_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + prefix_block_groups = _dispatch_groups(histogram_size, reduction_group_size) + prefix_accum_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + bucket_scan_groups = _dispatch_groups(histogram_size, reduction_group_size) + scatter_groups = num_groups + histogram_groups = num_groups + + self._map_to_uint_program.run( + [map_buffer, key_buffers[0]], + [key_buffers[1]], + map_groups, + ) + + key_in = 
key_buffers[1] + key_out = key_buffers[0] + val_in, val_out = value_buffers + + for pass_index in range(passes): + params_array[2] = np.uint32(pass_index * self.bits_per_pass) params_buffer.setData(params_array) - map_buffer = tf.createBuffer(map_params.size, 4, True) - stack.callback(map_buffer.release) - map_buffer.setData(map_params) + self._histogram_program.run( + [params_buffer, key_in], + [packed_hist_buffer], + histogram_groups, + ) - key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] - for buf in key_buffers: - stack.callback(buf.release) - key_buffers[0].setData(keys_array) + self._unpack_program.run( + [params_buffer, packed_hist_buffer], + [group_hist_buffer], + unpack_groups, + ) - if values_array is not None: - value_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] - for buf in value_buffers: - stack.callback(buf.release) - value_buffers[0].setData(values_array) - else: - dummy = self._dummy_values_buffer - value_buffers = [dummy, dummy] - - packed_hist_buffer = tf.createBuffer(max(packed_count * num_groups, 1), 4, False) - group_hist_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) - prefix_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) - block_totals_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) - block_prefix_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) - bucket_scan_buffer = tf.createBuffer(max(histogram_size, 1), 4, False) - - for buf in ( - packed_hist_buffer, - group_hist_buffer, - prefix_buffer, - block_totals_buffer, - block_prefix_buffer, - bucket_scan_buffer, - ): - stack.callback(buf.release) - - map_groups = _dispatch_groups(element_count, self.group_size) - reduction_group_size = 64 - unpack_groups = _dispatch_groups(histogram_size * num_groups, reduction_group_size) - prefix_local_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) - 
prefix_block_groups = _dispatch_groups(histogram_size, reduction_group_size) - prefix_accum_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) - bucket_scan_groups = _dispatch_groups(histogram_size, reduction_group_size) - scatter_groups = num_groups - histogram_groups = num_groups - - self._map_to_uint_program.run( - [map_buffer, key_buffers[0]], - [key_buffers[1]], - map_groups, + self._prefix_local_program.run( + [params_buffer, group_hist_buffer], + [prefix_buffer, block_totals_buffer], + prefix_local_groups, ) - key_in = key_buffers[1] - key_out = key_buffers[0] - val_in, val_out = value_buffers - - for pass_index in range(passes): - params_array[2] = np.uint32(pass_index * self.bits_per_pass) - params_buffer.setData(params_array) - - self._histogram_program.run( - [params_buffer, key_in], - [packed_hist_buffer], - histogram_groups, - ) - - self._unpack_program.run( - [params_buffer, packed_hist_buffer], - [group_hist_buffer], - unpack_groups, - ) - - self._prefix_local_program.run( - [params_buffer, group_hist_buffer], - [prefix_buffer, block_totals_buffer], - prefix_local_groups, - ) - - self._prefix_blocks_program.run( - [params_buffer, block_totals_buffer], - [block_prefix_buffer], - prefix_block_groups, - ) - - self._prefix_accum_program.run( - [params_buffer, block_prefix_buffer], - [prefix_buffer], - prefix_accum_groups, - ) - - self._bucket_scan_program.run( - [params_buffer, prefix_buffer], - [bucket_scan_buffer], - bucket_scan_groups, - ) - - self._scatter_program.run( - [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], - [key_out, val_out], - scatter_groups, - ) - - key_in, key_out = key_out, key_in - if values_array is not None: - val_in, val_out = val_out, val_in - - self._map_from_uint_program.run( - [map_buffer, key_in], - [key_out], - map_groups, + self._prefix_blocks_program.run( + [params_buffer, block_totals_buffer], + [block_prefix_buffer], + prefix_block_groups, ) - sorted_keys = 
key_out.getData(key_dtype, element_count) - if values_array is not None and values_dtype is not None: - sorted_values = val_in.getData(values_dtype, element_count) - else: - sorted_values = None + self._prefix_accum_program.run( + [params_buffer, block_prefix_buffer], + [prefix_buffer], + prefix_accum_groups, + ) + + self._bucket_scan_program.run( + [params_buffer, prefix_buffer], + [bucket_scan_buffer], + bucket_scan_groups, + ) + + self._scatter_program.run( + [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], + [key_out, val_out], + scatter_groups, + ) + + key_in, key_out = key_out, key_in + if values_array is not None: + val_in, val_out = val_out, val_in + + self._map_from_uint_program.run( + [map_buffer, key_in], + [key_out], + map_groups, + ) + + sorted_keys = key_out.getData(key_dtype, element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_in.getData(values_dtype, element_count) + else: + sorted_values = None return sorted_keys, sorted_values diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py index b26999bd..dd404a0f 100644 --- a/examples/Slang/mandelbrot.py +++ b/examples/Slang/mandelbrot.py @@ -48,7 +48,6 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: nonlocal pixel_buffer, pixel_capacity required = max(1, cur_width * cur_height) if required != pixel_capacity: - pixel_buffer.release() pixel_buffer = tf.createBuffer(required, 4, False) pixel_capacity = required @@ -148,8 +147,8 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: pending_scroll += scroll_dy finally: win.close() - pixel_buffer.release() - params_buffer.release() + pixel_buffer = None + params_buffer = None if __name__ == "__main__": diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py index e141c6b4..19601839 100644 --- a/examples/radix_sort/__main__.py +++ b/examples/radix_sort/__main__.py @@ -4,7 +4,6 @@ import math import time from collections 
import deque -from contextlib import ExitStack import numpy as np import TensorFrost as tf @@ -110,22 +109,22 @@ def main() -> None: sort_count = 0 total_kernel_time = 0.0 - with ExitStack() as stack: + sorter = None + window = None + try: sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) - stack.callback(sorter.close) window_title = f"TensorFrost Radix Sort ({count:,} elements)" window = tf.createWindow(window_width, window_height, window_title) - stack.callback(window.close) - if font_scale > 0.0: + if font_scale > 0.0 and window is not None: window.imgui_scale_all_sizes(font_scale) window.imgui_set_font_global_scale(font_scale) start_time = time.perf_counter() last_frame_time = start_time - while window.isOpen(): + while window is not None and window.isOpen(): now = time.perf_counter() dt = now - last_frame_time last_frame_time = now @@ -258,6 +257,11 @@ def main() -> None: window.imgui_end() window.present() + finally: + if window is not None: + window.close() + if sorter is not None: + sorter.close() if __name__ == "__main__": diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py index 22dd11a5..fe0104e6 100644 --- a/examples/radix_sort/sort.py +++ b/examples/radix_sort/sort.py @@ -1,6 +1,5 @@ from __future__ import annotations -from contextlib import ExitStack from dataclasses import dataclass from pathlib import Path from typing import Dict, Optional, Tuple @@ -263,190 +262,171 @@ def sort( else: stage_totals = {} - with ExitStack() as stack: - params_buffer = tf.createBuffer(params_array.size, 4, True) - stack.callback(params_buffer.release) + params_buffer = tf.createBuffer(params_array.size, 4, True) + params_buffer.setData(params_array) + + map_buffer = tf.createBuffer(map_params.size, 4, True) + map_buffer.setData(map_params) + + key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + key_buffers[0].setData(keys_array) + + if values_array is not None: + value_buffers = 
[tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + value_buffers[0].setData(values_array) + else: + dummy = self._dummy_values_buffer + value_buffers = [dummy, dummy] + + packed_hist_buffer = tf.createBuffer(max(packed_count * num_groups, 1), 4, False) + group_hist_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + prefix_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + block_totals_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + block_prefix_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + bucket_scan_buffer = tf.createBuffer(max(histogram_size, 1), 4, False) + + map_groups = _dispatch_groups(element_count, self.group_size) + reduction_group_size = 64 + unpack_groups = _dispatch_groups(histogram_size * num_groups, reduction_group_size) + prefix_local_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + prefix_block_groups = _dispatch_groups(histogram_size, reduction_group_size) + prefix_accum_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + bucket_scan_groups = _dispatch_groups(histogram_size, reduction_group_size) + scatter_groups = num_groups + histogram_groups = num_groups + + # Total pass timer starts at the first dispatch and ends after the last (map_from_uint) + total_start = time.perf_counter() if collect_stage_timings else None + start = time.perf_counter() if collect_stage_timings else None + self._map_to_uint_program.run( + [map_buffer, key_buffers[0]], + [key_buffers[1]], + map_groups, + ) + if collect_stage_timings and start is not None: + stage_totals["map_to_uint"] += time.perf_counter() - start + + key_in = key_buffers[1] + key_out = key_buffers[0] + val_in, val_out = value_buffers + + for pass_index in range(passes): + params_array[2] = np.uint32(pass_index * self.bits_per_pass) params_buffer.setData(params_array) - map_buffer = tf.createBuffer(map_params.size, 4, True) - 
stack.callback(map_buffer.release) - map_buffer.setData(map_params) + start = time.perf_counter() if collect_stage_timings else None + self._histogram_program.run( + [params_buffer, key_in], + [packed_hist_buffer], + histogram_groups, + ) + if collect_stage_timings and start is not None: + stage_totals["histogram"] += time.perf_counter() - start - key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] - for buf in key_buffers: - stack.callback(buf.release) - key_buffers[0].setData(keys_array) + start = time.perf_counter() if collect_stage_timings else None + self._unpack_program.run( + [params_buffer, packed_hist_buffer], + [group_hist_buffer], + unpack_groups, + ) + if collect_stage_timings and start is not None: + stage_totals["unpack"] += time.perf_counter() - start - if values_array is not None: - value_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] - for buf in value_buffers: - stack.callback(buf.release) - value_buffers[0].setData(values_array) - else: - dummy = self._dummy_values_buffer - value_buffers = [dummy, dummy] - - packed_hist_buffer = tf.createBuffer(max(packed_count * num_groups, 1), 4, False) - group_hist_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) - prefix_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) - block_totals_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) - block_prefix_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) - bucket_scan_buffer = tf.createBuffer(max(histogram_size, 1), 4, False) - - for buf in ( - packed_hist_buffer, - group_hist_buffer, - prefix_buffer, - block_totals_buffer, - block_prefix_buffer, - bucket_scan_buffer, - ): - stack.callback(buf.release) - - map_groups = _dispatch_groups(element_count, self.group_size) - reduction_group_size = 64 - unpack_groups = _dispatch_groups(histogram_size * num_groups, reduction_group_size) - prefix_local_groups = 
_dispatch_groups(histogram_size * block_count, reduction_group_size) - prefix_block_groups = _dispatch_groups(histogram_size, reduction_group_size) - prefix_accum_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) - bucket_scan_groups = _dispatch_groups(histogram_size, reduction_group_size) - scatter_groups = num_groups - histogram_groups = num_groups - - # Total pass timer starts at the first dispatch and ends after the last (map_from_uint) - total_start = time.perf_counter() if collect_stage_timings else None start = time.perf_counter() if collect_stage_timings else None - self._map_to_uint_program.run( - [map_buffer, key_buffers[0]], - [key_buffers[1]], - map_groups, + self._prefix_local_program.run( + [params_buffer, group_hist_buffer], + [prefix_buffer, block_totals_buffer], + prefix_local_groups, ) if collect_stage_timings and start is not None: - stage_totals["map_to_uint"] += time.perf_counter() - start - - key_in = key_buffers[1] - key_out = key_buffers[0] - val_in, val_out = value_buffers - - for pass_index in range(passes): - params_array[2] = np.uint32(pass_index * self.bits_per_pass) - params_buffer.setData(params_array) - - start = time.perf_counter() if collect_stage_timings else None - self._histogram_program.run( - [params_buffer, key_in], - [packed_hist_buffer], - histogram_groups, - ) - if collect_stage_timings and start is not None: - stage_totals["histogram"] += time.perf_counter() - start - - start = time.perf_counter() if collect_stage_timings else None - self._unpack_program.run( - [params_buffer, packed_hist_buffer], - [group_hist_buffer], - unpack_groups, - ) - if collect_stage_timings and start is not None: - stage_totals["unpack"] += time.perf_counter() - start - - start = time.perf_counter() if collect_stage_timings else None - self._prefix_local_program.run( - [params_buffer, group_hist_buffer], - [prefix_buffer, block_totals_buffer], - prefix_local_groups, - ) - if collect_stage_timings and start is not 
None: - stage_totals["prefix_local"] += time.perf_counter() - start - - start = time.perf_counter() if collect_stage_timings else None - self._prefix_blocks_program.run( - [params_buffer, block_totals_buffer], - [block_prefix_buffer], - prefix_block_groups, - ) - if collect_stage_timings and start is not None: - stage_totals["prefix_blocks"] += time.perf_counter() - start - - start = time.perf_counter() if collect_stage_timings else None - self._prefix_accum_program.run( - [params_buffer, block_prefix_buffer], - [prefix_buffer], - prefix_accum_groups, - ) - if collect_stage_timings and start is not None: - stage_totals["prefix_accum"] += time.perf_counter() - start - - start = time.perf_counter() if collect_stage_timings else None - self._bucket_scan_program.run( - [params_buffer, prefix_buffer], - [bucket_scan_buffer], - bucket_scan_groups, - ) - if collect_stage_timings and start is not None: - stage_totals["bucket_scan"] += time.perf_counter() - start - - start = time.perf_counter() if collect_stage_timings else None - self._scatter_program.run( - [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], - [key_out, val_out], - scatter_groups, - ) - if collect_stage_timings and start is not None: - stage_totals["scatter"] += time.perf_counter() - start - - key_in, key_out = key_out, key_in - if values_array is not None: - val_in, val_out = val_out, val_in + stage_totals["prefix_local"] += time.perf_counter() - start start = time.perf_counter() if collect_stage_timings else None - self._map_from_uint_program.run( - [map_buffer, key_in], - [key_out], - map_groups, + self._prefix_blocks_program.run( + [params_buffer, block_totals_buffer], + [block_prefix_buffer], + prefix_block_groups, ) if collect_stage_timings and start is not None: - stage_totals["map_from_uint"] += time.perf_counter() - start + stage_totals["prefix_blocks"] += time.perf_counter() - start - # Record total time from first to last dispatch in the sort pass - if collect_stage_timings and 
total_start is not None: - stage_totals["total_pass"] = time.perf_counter() - total_start + start = time.perf_counter() if collect_stage_timings else None + self._prefix_accum_program.run( + [params_buffer, block_prefix_buffer], + [prefix_buffer], + prefix_accum_groups, + ) + if collect_stage_timings and start is not None: + stage_totals["prefix_accum"] += time.perf_counter() - start - # Optional GPU-side validation: check adjacent pairs and atomically count violations - if validate: - validate_params = np.zeros(2, dtype=np.uint32) - validate_params[0] = np.uint32(element_count) - validate_params[1] = map_params[1] # type code - - validate_params_buf = tf.createBuffer(validate_params.size, 4, True) - stack.callback(validate_params_buf.release) - validate_params_buf.setData(validate_params) - - error_buf = tf.createBuffer(1, 4, False) - stack.callback(error_buf.release) - error_zero = np.zeros(1, dtype=np.uint32) - error_buf.setData(error_zero) - - # Reuse map_groups; kernel early-outs for i >= n-1 - self._validate_program.run( - [validate_params_buf, key_out], - [error_buf], - map_groups, - ) - - error_count = int(error_buf.getData(np.dtype(np.uint32), 1)[0]) - self.last_validation_errors = error_count - - if return_arrays: - sorted_keys = key_out.getData(key_dtype, element_count) - if values_array is not None and values_dtype is not None: - sorted_values = val_in.getData(values_dtype, element_count) - else: - sorted_values = None + start = time.perf_counter() if collect_stage_timings else None + self._bucket_scan_program.run( + [params_buffer, prefix_buffer], + [bucket_scan_buffer], + bucket_scan_groups, + ) + if collect_stage_timings and start is not None: + stage_totals["bucket_scan"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._scatter_program.run( + [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], + [key_out, val_out], + scatter_groups, + ) + if collect_stage_timings and start is 
not None: + stage_totals["scatter"] += time.perf_counter() - start + + key_in, key_out = key_out, key_in + if values_array is not None: + val_in, val_out = val_out, val_in + + start = time.perf_counter() if collect_stage_timings else None + self._map_from_uint_program.run( + [map_buffer, key_in], + [key_out], + map_groups, + ) + if collect_stage_timings and start is not None: + stage_totals["map_from_uint"] += time.perf_counter() - start + + # Record total time from first to last dispatch in the sort pass + if collect_stage_timings and total_start is not None: + stage_totals["total_pass"] = time.perf_counter() - total_start + + # Optional GPU-side validation: check adjacent pairs and atomically count violations + if validate: + validate_params = np.zeros(2, dtype=np.uint32) + validate_params[0] = np.uint32(element_count) + validate_params[1] = map_params[1] # type code + + validate_params_buf = tf.createBuffer(validate_params.size, 4, True) + validate_params_buf.setData(validate_params) + + error_buf = tf.createBuffer(1, 4, False) + error_zero = np.zeros(1, dtype=np.uint32) + error_buf.setData(error_zero) + + # Reuse map_groups; kernel early-outs for i >= n-1 + self._validate_program.run( + [validate_params_buf, key_out], + [error_buf], + map_groups, + ) + + error_count = int(error_buf.getData(np.dtype(np.uint32), 1)[0]) + self.last_validation_errors = error_count + + if return_arrays: + sorted_keys = key_out.getData(key_dtype, element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_in.getData(values_dtype, element_count) else: - # Avoid full readback when not needed by caller - sorted_keys = np.empty(0, dtype=key_dtype) - sorted_values = (np.empty(0, dtype=values_dtype) if values_array is not None and values_dtype is not None else None) + sorted_values = None + else: + # Avoid full readback when not needed by caller + sorted_keys = np.empty(0, dtype=key_dtype) + sorted_values = (np.empty(0, dtype=values_dtype) if 
values_array is not None and values_dtype is not None else None) self.last_stage_timings = stage_totals if collect_stage_timings else None return sorted_keys, sorted_values diff --git a/tests/slang_compile_test.py b/tests/slang_compile_test.py index 0f99e2f2..bb2c2f8a 100644 --- a/tests/slang_compile_test.py +++ b/tests/slang_compile_test.py @@ -1,5 +1,4 @@ import unittest -from contextlib import ExitStack from pathlib import Path import numpy as np @@ -75,39 +74,40 @@ def test_compile_and_execute_simple_shader(self) -> None: except RuntimeError as exc: # pragma: no cover - Vulkan not available self.skipTest(f"Vulkan buffer creation failed: {exc}") - with ExitStack() as resources: - resources.callback(readonly_buffer.release) + try: + readwrite_buffer = tf.createBuffer(thread_count, 4, False) + except RuntimeError as exc: # pragma: no cover - Vulkan not available + self.skipTest(f"Vulkan buffer creation failed: {exc}") - try: - readwrite_buffer = tf.createBuffer(thread_count, 4, False) - except RuntimeError as exc: # pragma: no cover - Vulkan not available - self.skipTest(f"Vulkan buffer creation failed: {exc}") + try: + program = tf.createComputeProgramFromSlang( + "tensorfrost_test_shader", + _SIMPLE_SLANG_SHADER, + "csMain", + ro_count=1, + rw_count=1, + ) + except RuntimeError as exc: + if _should_skip_for_backend(exc): # pragma: no cover - Vulkan not available + self.skipTest(f"Slang compilation backend unavailable: {exc}") + raise - resources.callback(readwrite_buffer.release) + readonly_buffer.setData(np.array([7], dtype=np.uint32)) + readwrite_buffer.setData(np.zeros(1, dtype=np.uint32)) - try: - program = tf.createComputeProgramFromSlang( - "tensorfrost_test_shader", - _SIMPLE_SLANG_SHADER, - "csMain", - ro_count=1, - rw_count=1, - ) - except RuntimeError as exc: - if _should_skip_for_backend(exc): # pragma: no cover - Vulkan not available - self.skipTest(f"Slang compilation backend unavailable: {exc}") - raise + program.run([readonly_buffer], 
[readwrite_buffer], group_count) - resources.callback(program.release) + result = readwrite_buffer.getData(np.dtype(np.uint32), thread_count) + self.assertEqual(result.shape, (thread_count,)) + self.assertEqual(int(result[0]), 8) - readonly_buffer.setData(np.array([7], dtype=np.uint32)) - readwrite_buffer.setData(np.zeros(1, dtype=np.uint32)) + readonly_buffer = None + readwrite_buffer = None + program = None - program.run([readonly_buffer], [readwrite_buffer], group_count) + import gc - result = readwrite_buffer.getData(np.dtype(np.uint32), thread_count) - self.assertEqual(result.shape, (thread_count,)) - self.assertEqual(int(result[0]), 8) + gc.collect() if __name__ == "__main__": diff --git a/tests/vulkan_window_test.py b/tests/vulkan_window_test.py index 004d78ad..9f871fd5 100644 --- a/tests/vulkan_window_test.py +++ b/tests/vulkan_window_test.py @@ -1,5 +1,4 @@ import unittest -from contextlib import ExitStack import numpy as np @@ -30,12 +29,14 @@ def test_compute_dispatch_and_window_present(self): except RuntimeError as exc: # pragma: no cover - Vulkan not available self.skipTest(f"Vulkan buffer creation failed: {exc}") - with ExitStack() as resources: - resources.callback(pixel_buffer.release) - + program = None + try: program = tf.createComputeProgramFromGLSL(_SIMPLE_GLSL, ro_count=0, rw_count=1) - resources.callback(program.release) + except RuntimeError as exc: # pragma: no cover - Vulkan program creation failed + self.skipTest(f"Vulkan program creation failed: {exc}") + window = None + try: program.run([], [pixel_buffer], group_count) pixels = pixel_buffer.getData(np.dtype(np.uint32), thread_count) self.assertTrue(np.all(pixels == 0xFF3366FF), "Compute shader did not write expected color") @@ -45,11 +46,20 @@ def test_compute_dispatch_and_window_present(self): except RuntimeError as exc: # pragma: no cover - Vulkan window not available self.skipTest(f"Vulkan window creation failed: {exc}") - resources.callback(window.close) - # Present once to ensure 
the binding path is exercised. window.drawBuffer(pixel_buffer, width, height) self.assertTrue(window.isOpen(), "Window should report as open after initial present") + finally: + if window is not None: + window.close() + window = None + program = None + pixel_buffer = None + + # Ensure GPU resources are released before interpreter shutdown. + import gc + + gc.collect() if __name__ == "__main__": From dd1f6861f193145bb6029cf77f138bbac93fc106 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sun, 9 Nov 2025 03:29:07 +0100 Subject: [PATCH 41/44] Remove more useless code --- examples/Slang/mandelbrot.py | 187 ++++++++++----------- examples/debug.py | 11 +- examples/imgui_showcase.py | 7 +- examples/radix_sort/__main__.py | 286 ++++++++++++++++---------------- examples/radix_sort/sort.py | 18 -- tests/imgui_test.py | 42 +++-- tests/vulkan_window_test.py | 35 ++-- 7 files changed, 282 insertions(+), 304 deletions(-) diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py index dd404a0f..6803d450 100644 --- a/examples/Slang/mandelbrot.py +++ b/examples/Slang/mandelbrot.py @@ -51,104 +51,99 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: pixel_buffer = tf.createBuffer(required, 4, False) pixel_capacity = required - try: - while win.isOpen(): - now = time.perf_counter() - dt = max(now - prev_time, 1e-6) - prev_time = now - fps = fps * 0.9 + (1.0 / dt) * 0.1 - - width, height = win.size - width = max(1, int(width)) - height = max(1, int(height)) - thread_count = max(1, width * height) - group_count = max((thread_count + local_size - 1) // local_size, 1) - - ensure_pixel_buffer(width, height) - - if pending_scroll: - scroll_adjust = pending_scroll * 0.12 - log_scale = max(min(log_scale - scroll_adjust, 0.5), -4.5) - scale = pow(10.0, log_scale) + while win.isOpen(): + now = time.perf_counter() + dt = max(now - prev_time, 1e-6) + prev_time = now + fps = fps * 0.9 + (1.0 / dt) * 0.1 + + 
width, height = win.size + width = max(1, int(width)) + height = max(1, int(height)) + thread_count = max(1, width * height) + group_count = max((thread_count + local_size - 1) // local_size, 1) + + ensure_pixel_buffer(width, height) + + if pending_scroll: + scroll_adjust = pending_scroll * 0.12 + log_scale = max(min(log_scale - scroll_adjust, 0.5), -4.5) + scale = pow(10.0, log_scale) + pending_scroll = 0.0 + + aspect = height / float(width) + + visible, _ = win.imgui_begin("Mandelbrot Controls", open=True) + if visible: + win.imgui_text(f"Resolution: {width} × {height}") + win.imgui_text(f"Format: {fmt}") + win.imgui_text(f"FPS: {fps:5.1f} | {dt * 1000.0:.2f} ms") + log_scale = win.imgui_slider_float("log₁₀ scale", log_scale, -4.5, 0.5) + scale = pow(10.0, log_scale) + center[0] = win.imgui_slider_float("Center X", center[0], -2.5, 1.5) + center[1] = win.imgui_slider_float("Center Y", center[1], -1.5, 1.5) + auto_iterations = win.imgui_checkbox("Auto iterations", auto_iterations) + if auto_iterations: + auto_value = max(64, int(200 + (-math.log10(scale)) * 120)) + manual_iterations = auto_value + win.imgui_text(f"Max iterations (auto): {auto_value}") + else: + manual_iterations = win.imgui_slider_int("Max iterations", manual_iterations, 64, 5000) + swap_rb = win.imgui_checkbox("BGRA swap", swap_rb) + if win.imgui_button("Reset view"): + center[:] = (-0.5, 0.0) + scale = 3.0 + log_scale = math.log10(scale) + manual_iterations = 500 + auto_iterations = True pending_scroll = 0.0 - aspect = height / float(width) - - visible, _ = win.imgui_begin("Mandelbrot Controls", open=True) - if visible: - win.imgui_text(f"Resolution: {width} × {height}") - win.imgui_text(f"Format: {fmt}") - win.imgui_text(f"FPS: {fps:5.1f} | {dt * 1000.0:.2f} ms") - log_scale = win.imgui_slider_float("log₁₀ scale", log_scale, -4.5, 0.5) - scale = pow(10.0, log_scale) - center[0] = win.imgui_slider_float("Center X", center[0], -2.5, 1.5) - center[1] = win.imgui_slider_float("Center Y", center[1], 
-1.5, 1.5) - auto_iterations = win.imgui_checkbox("Auto iterations", auto_iterations) - if auto_iterations: - auto_value = max(64, int(200 + (-math.log10(scale)) * 120)) - manual_iterations = auto_value - win.imgui_text(f"Max iterations (auto): {auto_value}") - else: - manual_iterations = win.imgui_slider_int("Max iterations", manual_iterations, 64, 5000) - swap_rb = win.imgui_checkbox("BGRA swap", swap_rb) - if win.imgui_button("Reset view"): - center[:] = (-0.5, 0.0) - scale = 3.0 - log_scale = math.log10(scale) - manual_iterations = 500 - auto_iterations = True - pending_scroll = 0.0 - - plot_history[history_index % plot_history.size] = float(manual_iterations) - history_index += 1 - win.imgui_plot_lines( - "Iteration history", - plot_history, - values_offset=history_index % plot_history.size, - graph_size=(0.0, 60.0), - ) - win.imgui_end() - - xspan = scale - yspan = xspan * aspect - dx = xspan / width - dy = yspan / height - - mouse_pos = win.mouse_position() - want_capture_mouse = win.imgui_want_capture_mouse() - mouse_down = win.is_mouse_button_pressed(0) - dragging_now = mouse_down and not want_capture_mouse - if dragging and dragging_now: - delta_x = mouse_pos[0] - prev_mouse_pos[0] - delta_y = mouse_pos[1] - prev_mouse_pos[1] - center[0] -= delta_x * dx - center[1] -= delta_y * dy - dragging = dragging_now - prev_mouse_pos = mouse_pos - - xmin = center[0] - xspan * 0.5 - ymin = center[1] - yspan * 0.5 - params[0] = float(width) - params[1] = float(height) - params[2] = xmin - params[3] = ymin - params[4] = dx - params[5] = dy - params[6] = float(manual_iterations) - params[7] = 1.0 if swap_rb else 0.0 - - params_buffer.setData(params) - program.run([params_buffer], [pixel_buffer], group_count) - - win.drawBuffer(pixel_buffer, width, height) - - _, scroll_dy = win.consume_scroll_delta() - if not want_capture_mouse: - pending_scroll += scroll_dy - finally: - win.close() - pixel_buffer = None - params_buffer = None + plot_history[history_index % 
plot_history.size] = float(manual_iterations) + history_index += 1 + win.imgui_plot_lines( + "Iteration history", + plot_history, + values_offset=history_index % plot_history.size, + graph_size=(0.0, 60.0), + ) + win.imgui_end() + + xspan = scale + yspan = xspan * aspect + dx = xspan / width + dy = yspan / height + + mouse_pos = win.mouse_position() + want_capture_mouse = win.imgui_want_capture_mouse() + mouse_down = win.is_mouse_button_pressed(0) + dragging_now = mouse_down and not want_capture_mouse + if dragging and dragging_now: + delta_x = mouse_pos[0] - prev_mouse_pos[0] + delta_y = mouse_pos[1] - prev_mouse_pos[1] + center[0] -= delta_x * dx + center[1] -= delta_y * dy + dragging = dragging_now + prev_mouse_pos = mouse_pos + + xmin = center[0] - xspan * 0.5 + ymin = center[1] - yspan * 0.5 + params[0] = float(width) + params[1] = float(height) + params[2] = xmin + params[3] = ymin + params[4] = dx + params[5] = dy + params[6] = float(manual_iterations) + params[7] = 1.0 if swap_rb else 0.0 + + params_buffer.setData(params) + program.run([params_buffer], [pixel_buffer], group_count) + + win.drawBuffer(pixel_buffer, width, height) + + _, scroll_dy = win.consume_scroll_delta() + if not want_capture_mouse: + pending_scroll += scroll_dy if __name__ == "__main__": diff --git a/examples/debug.py b/examples/debug.py index 5ffab1a5..f09492df 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -73,12 +73,11 @@ def main(): p = np.array([float(W), float(H), xmin, ymin, dx, dy, max_iter, 1.0 if is_bgra else 0.0], dtype=np.float32) params.setData(p) - try: - while win.isOpen(): - prog.run([params], [pix], group_count) - win.drawBuffer(pix, W, H) - finally: - win.close() + while win.isOpen(): + prog.run([params], [pix], group_count) + win.drawBuffer(pix, W, H) + + win.close() if __name__ == "__main__": main() \ No newline at end of file diff --git a/examples/imgui_showcase.py b/examples/imgui_showcase.py index c192eb05..d3f52222 100644 --- 
a/examples/imgui_showcase.py +++ b/examples/imgui_showcase.py @@ -73,8 +73,7 @@ def apply_theme(name: str) -> None: window.imgui_scale_all_sizes(2.0) window.imgui_set_font_global_scale(state["font_scale"]) - try: - while window.isOpen(): + while window.isOpen(): now = time.perf_counter() dt = now - last_time last_time = now @@ -263,8 +262,8 @@ def apply_theme(name: str) -> None: ) window.present() - finally: - window.close() + + window.close() if __name__ == "__main__": diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py index 19601839..94260244 100644 --- a/examples/radix_sort/__main__.py +++ b/examples/radix_sort/__main__.py @@ -109,160 +109,152 @@ def main() -> None: sort_count = 0 total_kernel_time = 0.0 - sorter = None - window = None - try: - sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) - - window_title = f"TensorFrost Radix Sort ({count:,} elements)" - window = tf.createWindow(window_width, window_height, window_title) - - if font_scale > 0.0 and window is not None: - window.imgui_scale_all_sizes(font_scale) - window.imgui_set_font_global_scale(font_scale) - - start_time = time.perf_counter() - last_frame_time = start_time - - while window is not None and window.isOpen(): - now = time.perf_counter() - dt = now - last_frame_time - last_frame_time = now - - if dt > 0.0: - frame_times.append(dt) - frame_time_total = sum(frame_times) - fps = (len(frame_times) / frame_time_total) if frame_times and frame_time_total > 0.0 else 0.0 - - # Perform sort; on the first run also validate on GPU and avoid full array readback. 
- do_validate = not validated - _keys_out, _vals_out = sorter.sort( - keys, - values, - collect_stage_timings=True, - validate=do_validate, - return_arrays=False, + sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) + + window_title = f"TensorFrost Radix Sort ({count:,} elements)" + window = tf.createWindow(window_width, window_height, window_title) + + if font_scale > 0.0 and window is not None: + window.imgui_scale_all_sizes(font_scale) + window.imgui_set_font_global_scale(font_scale) + + start_time = time.perf_counter() + last_frame_time = start_time + + while window is not None and window.isOpen(): + now = time.perf_counter() + dt = now - last_frame_time + last_frame_time = now + + if dt > 0.0: + frame_times.append(dt) + frame_time_total = sum(frame_times) + fps = (len(frame_times) / frame_time_total) if frame_times and frame_time_total > 0.0 else 0.0 + + # Perform sort; on the first run also validate on GPU and avoid full array readback. + do_validate = not validated + _keys_out, _vals_out = sorter.sort( + keys, + values, + collect_stage_timings=True, + validate=do_validate, + return_arrays=False, + ) + stage_timings = sorter.last_stage_timings or {} + # Separate total_pass (overall) from per-stage summed time + total_pass_time = float(stage_timings.get("total_pass", 0.0)) + stage_sum_time = float(sum(v for k, v in stage_timings.items() if k != "total_pass")) + kernel_time = total_pass_time if total_pass_time > 0.0 else stage_sum_time + + if sort_times.maxlen == len(sort_times): + sort_times.popleft() + sort_times.append(kernel_time) + + sort_count += 1 + total_kernel_time += kernel_time + for name in _STAGE_NAMES: + stage_totals_overall[name] += stage_timings.get(name, 0.0) + + if do_validate: + errors = int(getattr(sorter, "last_validation_errors", 0) or 0) + validation_ok = (errors == 0) + validation_message = ( + "GPU validation passed (sorted)." 
if validation_ok else f"GPU validation failed: {errors} out-of-order pairs" ) - stage_timings = sorter.last_stage_timings or {} - # Separate total_pass (overall) from per-stage summed time - total_pass_time = float(stage_timings.get("total_pass", 0.0)) - stage_sum_time = float(sum(v for k, v in stage_timings.items() if k != "total_pass")) - kernel_time = total_pass_time if total_pass_time > 0.0 else stage_sum_time - - if sort_times.maxlen == len(sort_times): - sort_times.popleft() - sort_times.append(kernel_time) - - sort_count += 1 - total_kernel_time += kernel_time - for name in _STAGE_NAMES: - stage_totals_overall[name] += stage_timings.get(name, 0.0) - - if do_validate: - errors = int(getattr(sorter, "last_validation_errors", 0) or 0) - validation_ok = (errors == 0) - validation_message = ( - "GPU validation passed (sorted)." if validation_ok else f"GPU validation failed: {errors} out-of-order pairs" - ) - validated = True - - # No large array readback performed when return_arrays=False. 
- - window_avg_sort = (sum(sort_times) / len(sort_times)) if sort_times else 0.0 - last_sort_ms = kernel_time * 1000.0 - avg_sort_ms = window_avg_sort * 1000.0 - - last_sort_rate = (1.0 / kernel_time) if kernel_time > 0.0 else 0.0 - window_sort_rate = (1.0 / window_avg_sort) if window_avg_sort > 0.0 else 0.0 - overall_avg_sort = (total_kernel_time / sort_count) if sort_count else 0.0 - overall_sort_rate = (1.0 / overall_avg_sort) if overall_avg_sort > 0.0 else 0.0 - - last_elements_per_sec = (count / kernel_time) if count and kernel_time > 0.0 else 0.0 - window_elements_per_sec = (count / window_avg_sort) if count and window_avg_sort > 0.0 else 0.0 - - history_data = np.array(sort_times, dtype=np.float32) * 1000.0 - total_elapsed = now - start_time - - visible, _ = window.imgui_begin("Radix Sort Performance", open=None, flags=0) - if visible: - window.imgui_text(f"Elements: {count:,}") - window.imgui_text(f"Bits per pass: {bits_per_pass}") - window.imgui_separator() - window.imgui_text(f"Frame time: {dt * 1000.0:7.3f} ms") - window.imgui_text(f"Average FPS: {fps:6.1f}") - window.imgui_separator() - - if sort_count: - window.imgui_text(f"Last kernel time: {last_sort_ms:7.3f} ms") - if total_pass_time > 0.0: - window.imgui_text(f"Total pass time (last): {total_pass_time * 1000.0:7.3f} ms") - if window_avg_sort > 0.0: - window.imgui_text(f"Window avg kernel: {avg_sort_ms:7.3f} ms") - if overall_avg_sort > 0.0: - window.imgui_text(f"Overall avg kernel: {overall_avg_sort * 1000.0:7.3f} ms") - + validated = True + + # No large array readback performed when return_arrays=False. 
+ + window_avg_sort = (sum(sort_times) / len(sort_times)) if sort_times else 0.0 + last_sort_ms = kernel_time * 1000.0 + avg_sort_ms = window_avg_sort * 1000.0 + + last_sort_rate = (1.0 / kernel_time) if kernel_time > 0.0 else 0.0 + window_sort_rate = (1.0 / window_avg_sort) if window_avg_sort > 0.0 else 0.0 + overall_avg_sort = (total_kernel_time / sort_count) if sort_count else 0.0 + overall_sort_rate = (1.0 / overall_avg_sort) if overall_avg_sort > 0.0 else 0.0 + + last_elements_per_sec = (count / kernel_time) if count and kernel_time > 0.0 else 0.0 + window_elements_per_sec = (count / window_avg_sort) if count and window_avg_sort > 0.0 else 0.0 + + history_data = np.array(sort_times, dtype=np.float32) * 1000.0 + total_elapsed = now - start_time + + visible, _ = window.imgui_begin("Radix Sort Performance", open=None, flags=0) + if visible: + window.imgui_text(f"Elements: {count:,}") + window.imgui_text(f"Bits per pass: {bits_per_pass}") + window.imgui_separator() + window.imgui_text(f"Frame time: {dt * 1000.0:7.3f} ms") + window.imgui_text(f"Average FPS: {fps:6.1f}") + window.imgui_separator() + + if sort_count: + window.imgui_text(f"Last kernel time: {last_sort_ms:7.3f} ms") + if total_pass_time > 0.0: + window.imgui_text(f"Total pass time (last): {total_pass_time * 1000.0:7.3f} ms") + if window_avg_sort > 0.0: + window.imgui_text(f"Window avg kernel: {avg_sort_ms:7.3f} ms") + if overall_avg_sort > 0.0: + window.imgui_text(f"Overall avg kernel: {overall_avg_sort * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text(f"Sorts/sec (last): {last_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (window): {window_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (overall): {overall_sort_rate:8.2f}") + + if count: window.imgui_spacing() - window.imgui_text(f"Sorts/sec (last): {last_sort_rate:8.2f}") - window.imgui_text(f"Sorts/sec (window): {window_sort_rate:8.2f}") - window.imgui_text(f"Sorts/sec (overall): {overall_sort_rate:8.2f}") - - if count: - 
window.imgui_spacing() - window.imgui_text( - f"Elements/sec (last): {_format_rate(last_elements_per_sec, ' elements/s')}" - ) - window.imgui_text( - f"Elements/sec (window): {_format_rate(window_elements_per_sec, ' elements/s')}" - ) + window.imgui_text( + f"Elements/sec (last): {_format_rate(last_elements_per_sec, ' elements/s')}" + ) + window.imgui_text( + f"Elements/sec (window): {_format_rate(window_elements_per_sec, ' elements/s')}" + ) - window.imgui_spacing() - window.imgui_text(f"Total sorts: {sort_count:,}") + window.imgui_spacing() + window.imgui_text(f"Total sorts: {sort_count:,}") + else: + if count: + window.imgui_text("Waiting for first GPU sort...") else: - if count: - window.imgui_text("Waiting for first GPU sort...") - else: - window.imgui_text("No elements to sort.") - - if validated: - color = (0.20, 0.80, 0.40, 1.0) if validation_ok else (0.92, 0.35, 0.32, 1.0) - window.imgui_text_colored(color, validation_message) - elif count: - window.imgui_text("Validating against CPU reference...") - - if history_data.size: - max_plot = float(np.max(history_data)) if history_data.size else 1.0 - max_plot = max(1.0, max_plot * 1.1) - window.imgui_plot_lines( - "Kernel time history (ms)", - history_data, - values_offset=0, - overlay_text=f"{history_data.size} sample history", - scale_min=0.0, - scale_max=max_plot, - graph_size=(0.0, 140.0), - stride=4, - ) + window.imgui_text("No elements to sort.") + + if validated: + color = (0.20, 0.80, 0.40, 1.0) if validation_ok else (0.92, 0.35, 0.32, 1.0) + window.imgui_text_colored(color, validation_message) + elif count: + window.imgui_text("Validating against CPU reference...") + + if history_data.size: + max_plot = float(np.max(history_data)) if history_data.size else 1.0 + max_plot = max(1.0, max_plot * 1.1) + window.imgui_plot_lines( + "Kernel time history (ms)", + history_data, + values_offset=0, + overlay_text=f"{history_data.size} sample history", + scale_min=0.0, + scale_max=max_plot, + graph_size=(0.0, 
140.0), + stride=4, + ) - if stage_timings: - window.imgui_separator() - window.imgui_text("Kernel stages (last run):") - for name, duration in sorted(stage_timings.items(), key=lambda item: item[1], reverse=True): - window.imgui_text(f"{_format_stage_name(name)}: {duration * 1000.0:7.3f} ms") + if stage_timings: + window.imgui_separator() + window.imgui_text("Kernel stages (last run):") + for name, duration in sorted(stage_timings.items(), key=lambda item: item[1], reverse=True): + window.imgui_text(f"{_format_stage_name(name)}: {duration * 1000.0:7.3f} ms") - window.imgui_spacing() - window.imgui_text("Kernel stages (overall avg):") - for name in sorted(stage_totals_overall.keys()): - avg_duration = (stage_totals_overall[name] / sort_count) if sort_count else 0.0 - window.imgui_text(f"{_format_stage_name(name)}: {avg_duration * 1000.0:7.3f} ms") + window.imgui_spacing() + window.imgui_text("Kernel stages (overall avg):") + for name in sorted(stage_totals_overall.keys()): + avg_duration = (stage_totals_overall[name] / sort_count) if sort_count else 0.0 + window.imgui_text(f"{_format_stage_name(name)}: {avg_duration * 1000.0:7.3f} ms") - window.imgui_end() - window.present() - finally: - if window is not None: - window.close() - if sorter is not None: - sorter.close() + window.imgui_end() + window.present() if __name__ == "__main__": - main() + main() diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py index fe0104e6..49aeff10 100644 --- a/examples/radix_sort/sort.py +++ b/examples/radix_sort/sort.py @@ -177,24 +177,6 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b self._dummy_values_buffer = tf.createBuffer(1, 4, False) - def close(self) -> None: - for program in ( - self._map_to_uint_program, - self._map_from_uint_program, - self._histogram_program, - self._unpack_program, - self._prefix_local_program, - self._prefix_blocks_program, - self._prefix_accum_program, - self._bucket_scan_program, - 
self._scatter_program, - self._validate_program, - ): - if program is not None: - program.release() - if self._dummy_values_buffer is not None: - self._dummy_values_buffer.release() - def sort( self, keys: np.ndarray, diff --git a/tests/imgui_test.py b/tests/imgui_test.py index 371db5c2..11b33d2d 100644 --- a/tests/imgui_test.py +++ b/tests/imgui_test.py @@ -1,5 +1,4 @@ import unittest -from contextlib import contextmanager import numpy as np import TensorFrost as tf @@ -17,21 +16,34 @@ def _should_skip_for_backend(exc: RuntimeError) -> bool: return any(token in message for token in keywords) -@contextmanager -def managed_window(width=320, height=240, title="ImGui Test Window"): - try: - win = tf.createWindow(width, height, title) - except RuntimeError as exc: - if _should_skip_for_backend(exc): - raise unittest.SkipTest(f"Window backend unavailable: {exc}") from exc - raise - try: - yield win - finally: +class _ManagedWindow: + def __init__(self, width=320, height=240, title="ImGui Test Window"): + self._width = width + self._height = height + self._title = title + self._window = None + + def __enter__(self): try: - win.close() - except Exception: - pass + self._window = tf.createWindow(self._width, self._height, self._title) + except RuntimeError as exc: + if _should_skip_for_backend(exc): + raise unittest.SkipTest(f"Window backend unavailable: {exc}") from exc + raise + return self._window + + def __exit__(self, exc_type, exc, tb): + if self._window is not None: + try: + self._window.close() + except Exception: + pass + self._window = None + return False + + +def managed_window(width=320, height=240, title="ImGui Test Window"): + return _ManagedWindow(width, height, title) class ImGuiIntegrationTest(unittest.TestCase): diff --git a/tests/vulkan_window_test.py b/tests/vulkan_window_test.py index 9f871fd5..ebdb8f80 100644 --- a/tests/vulkan_window_test.py +++ b/tests/vulkan_window_test.py @@ -36,25 +36,24 @@ def test_compute_dispatch_and_window_present(self): 
self.skipTest(f"Vulkan program creation failed: {exc}") window = None + program.run([], [pixel_buffer], group_count) + pixels = pixel_buffer.getData(np.dtype(np.uint32), thread_count) + self.assertTrue(np.all(pixels == 0xFF3366FF), "Compute shader did not write expected color") + try: - program.run([], [pixel_buffer], group_count) - pixels = pixel_buffer.getData(np.dtype(np.uint32), thread_count) - self.assertTrue(np.all(pixels == 0xFF3366FF), "Compute shader did not write expected color") - - try: - window = tf.createWindow(width, height, "TensorFrost Vulkan Test") - except RuntimeError as exc: # pragma: no cover - Vulkan window not available - self.skipTest(f"Vulkan window creation failed: {exc}") - - # Present once to ensure the binding path is exercised. - window.drawBuffer(pixel_buffer, width, height) - self.assertTrue(window.isOpen(), "Window should report as open after initial present") - finally: - if window is not None: - window.close() - window = None - program = None - pixel_buffer = None + window = tf.createWindow(width, height, "TensorFrost Vulkan Test") + except RuntimeError as exc: # pragma: no cover - Vulkan window not available + self.skipTest(f"Vulkan window creation failed: {exc}") + + # Present once to ensure the binding path is exercised. + window.drawBuffer(pixel_buffer, width, height) + self.assertTrue(window.isOpen(), "Window should report as open after initial present") + + if window is not None: + window.close() + window = None + program = None + pixel_buffer = None # Ensure GPU resources are released before interpreter shutdown. 
import gc From 036e79bbb761669a4d22215a4fb88c4bfd906c67 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sun, 9 Nov 2025 04:27:09 +0100 Subject: [PATCH 42/44] Use push constants everywhere --- CMakeLists.txt | 4 +- Python/TensorFrost/sort.py | 54 ++++----- TensorFrost/Backend/CMakeLists.txt | 1 - TensorFrost/Backend/include/Backend/Vulkan.h | 8 +- TensorFrost/Backend/src/Vulkan.cpp | 98 ++++++++++------ .../src/Definitions/VulkanBindings.cpp | 17 ++- .../src/Definitions/VulkanInterface.cpp | 53 +++++++-- TensorFrost/src/Definitions/VulkanInterface.h | 8 +- examples/Slang/mandelbrot.py | 6 +- examples/Slang/mandelbrot.slang | 32 ++++-- examples/debug.py | 105 ++++++------------ examples/radix_sort/shaders/bucket_scan.slang | 23 +++- examples/radix_sort/shaders/histogram.slang | 29 +++-- .../radix_sort/shaders/map_from_uint.slang | 17 ++- examples/radix_sort/shaders/map_to_uint.slang | 17 ++- .../radix_sort/shaders/prefix_accum.slang | 27 +++-- .../radix_sort/shaders/prefix_block.slang | 23 +++- .../radix_sort/shaders/prefix_local.slang | 29 +++-- examples/radix_sort/shaders/scatter.slang | 37 ++++-- examples/radix_sort/shaders/unpack.slang | 23 +++- .../radix_sort/shaders/validate_sorted.slang | 19 +++- examples/radix_sort/sort.py | 64 +++++------ tests/vulkan_window_test.py | 21 ++-- 23 files changed, 440 insertions(+), 275 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 60c88f13..09010907 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,11 +34,11 @@ set(GLFW_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(GLFW_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) set(PYBIND11_FINDPYTHON ON) -find_package(Vulkan REQUIRED COMPONENTS shaderc_combined) +find_package(Vulkan REQUIRED) add_subdirectory(external/pybind11) add_subdirectory(external/glfw) add_subdirectory(TensorFrost) add_subdirectory(examples) -set_property(DIRECTORY ${CMAKE_SOURCE_DIR} PROPERTY VS_STARTUP_PROJECT TensorFrost) \ No newline at end of 
file +set_property(DIRECTORY ${CMAKE_SOURCE_DIR} PROPERTY VS_STARTUP_PROJECT TensorFrost) diff --git a/Python/TensorFrost/sort.py b/Python/TensorFrost/sort.py index 5148eca9..ef5e66e5 100644 --- a/Python/TensorFrost/sort.py +++ b/Python/TensorFrost/sort.py @@ -90,14 +90,14 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "radix_map_to_uint", _load_shader_source("map_to_uint.slang"), "csMapToUint", - ro_count=2, + ro_count=1, rw_count=1, ) self._map_from_uint_program = tf.createComputeProgramFromSlang( "radix_map_from_uint", _load_shader_source("map_from_uint.slang"), "csMapFromUint", - ro_count=2, + ro_count=1, rw_count=1, ) @@ -105,42 +105,42 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "radix_histogram", _load_shader_source("histogram.slang"), "csHistogram", - ro_count=2, + ro_count=1, rw_count=1, ) self._unpack_program = tf.createComputeProgramFromSlang( "radix_unpack", _load_shader_source("unpack.slang"), "csUnpack", - ro_count=2, + ro_count=1, rw_count=1, ) self._prefix_local_program = tf.createComputeProgramFromSlang( "radix_prefix_local", _load_shader_source("prefix_local.slang"), "csPrefixLocal", - ro_count=2, + ro_count=1, rw_count=2, ) self._prefix_blocks_program = tf.createComputeProgramFromSlang( "radix_prefix_blocks", _load_shader_source("prefix_block.slang"), "csPrefixBlocks", - ro_count=2, + ro_count=1, rw_count=1, ) self._prefix_accum_program = tf.createComputeProgramFromSlang( "radix_prefix_accum", _load_shader_source("prefix_accum.slang"), "csPrefixAccumulate", - ro_count=2, + ro_count=1, rw_count=1, ) self._bucket_scan_program = tf.createComputeProgramFromSlang( "radix_bucket_scan", _load_shader_source("bucket_scan.slang"), "csBucketScan", - ro_count=2, + ro_count=1, rw_count=1, ) scatter_source = f"#define TF_HISTOGRAM_SIZE {self.histogram_size}u\n" + _load_shader_source("scatter.slang") @@ -148,7 +148,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, 
group_size: "radix_scatter", scatter_source, "csScatter", - ro_count=5, + ro_count=4, rw_count=2, ) @@ -211,16 +211,10 @@ def sort( params_array[6] = np.uint32(block_count) params_array[7] = np.uint32(1 if values_array is not None else 0) - map_params = np.zeros(4, dtype=np.uint32) + map_params = np.zeros(2, dtype=np.uint32) map_params[0] = np.uint32(element_count) map_params[1] = _TYPE_CODES[key_kind] - params_buffer = tf.createBuffer(params_array.size, 4, True) - params_buffer.setData(params_array) - - map_buffer = tf.createBuffer(map_params.size, 4, True) - map_buffer.setData(map_params) - key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] key_buffers[0].setData(keys_array) @@ -249,9 +243,10 @@ def sort( histogram_groups = num_groups self._map_to_uint_program.run( - [map_buffer, key_buffers[0]], + [key_buffers[0]], [key_buffers[1]], map_groups, + map_params, ) key_in = key_buffers[1] @@ -260,48 +255,54 @@ def sort( for pass_index in range(passes): params_array[2] = np.uint32(pass_index * self.bits_per_pass) - params_buffer.setData(params_array) self._histogram_program.run( - [params_buffer, key_in], + [key_in], [packed_hist_buffer], histogram_groups, + params_array, ) self._unpack_program.run( - [params_buffer, packed_hist_buffer], + [packed_hist_buffer], [group_hist_buffer], unpack_groups, + params_array, ) self._prefix_local_program.run( - [params_buffer, group_hist_buffer], + [group_hist_buffer], [prefix_buffer, block_totals_buffer], prefix_local_groups, + params_array, ) self._prefix_blocks_program.run( - [params_buffer, block_totals_buffer], + [block_totals_buffer], [block_prefix_buffer], prefix_block_groups, + params_array, ) self._prefix_accum_program.run( - [params_buffer, block_prefix_buffer], + [block_prefix_buffer], [prefix_buffer], prefix_accum_groups, + params_array, ) self._bucket_scan_program.run( - [params_buffer, prefix_buffer], + [prefix_buffer], [bucket_scan_buffer], bucket_scan_groups, + params_array, ) 
self._scatter_program.run( - [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], + [key_in, val_in, prefix_buffer, bucket_scan_buffer], [key_out, val_out], scatter_groups, + params_array, ) key_in, key_out = key_out, key_in @@ -309,9 +310,10 @@ def sort( val_in, val_out = val_out, val_in self._map_from_uint_program.run( - [map_buffer, key_in], + [key_in], [key_out], map_groups, + map_params, ) sorted_keys = key_out.getData(key_dtype, element_count) diff --git a/TensorFrost/Backend/CMakeLists.txt b/TensorFrost/Backend/CMakeLists.txt index 58a93ef9..3289f00e 100644 --- a/TensorFrost/Backend/CMakeLists.txt +++ b/TensorFrost/Backend/CMakeLists.txt @@ -23,7 +23,6 @@ target_include_directories(TensorFrostBackend target_link_libraries(TensorFrostBackend PUBLIC Vulkan::Vulkan - Vulkan::shaderc_combined glfw $<$:${SLANG_LIB_DEBUG}> $<$>:${SLANG_LIB_RELEASE}>) diff --git a/TensorFrost/Backend/include/Backend/Vulkan.h b/TensorFrost/Backend/include/Backend/Vulkan.h index fc8e2715..071a78f1 100644 --- a/TensorFrost/Backend/include/Backend/Vulkan.h +++ b/TensorFrost/Backend/include/Backend/Vulkan.h @@ -18,6 +18,7 @@ struct ComputeProgram { vk::PipelineLayout pipelineLayout; vk::Pipeline pipeline; uint32_t numRO = 0, numRW = 0; + uint32_t pushConstantSize = 0; }; struct ComputeBindings { @@ -88,7 +89,6 @@ void destroyBuffer(Buffer& buf); void setBufferData(Buffer& buf, const void* src, size_t bytes, size_t offset = 0); void getBufferData(const Buffer& buf, void* dst, size_t bytes, size_t offset = 0); -ComputeProgram createComputeProgramFromGLSL(const std::string& glsl, uint32_t roCount, uint32_t rwCount); ComputeProgram createComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount); void destroyComputeProgram(ComputeProgram& prog); @@ -96,6 +96,8 @@ void destroyComputeProgram(ComputeProgram& prog); void runProgram(const ComputeProgram& prog, const std::vector& readonlyBuffers, 
const std::vector& readwriteBuffers, - uint32_t groupCount); + uint32_t groupCount, + const void* pushConstants, + size_t pushConstantSize); -VulkanContext& getVulkanContext(); \ No newline at end of file +VulkanContext& getVulkanContext(); diff --git a/TensorFrost/Backend/src/Vulkan.cpp b/TensorFrost/Backend/src/Vulkan.cpp index c9cce616..a36534df 100644 --- a/TensorFrost/Backend/src/Vulkan.cpp +++ b/TensorFrost/Backend/src/Vulkan.cpp @@ -1,9 +1,10 @@ #include "Backend/Vulkan.h" VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE -#include #include #include #include +#include +#include #include namespace { @@ -444,28 +445,15 @@ VulkanContext::~VulkanContext() { instance.destroy(); } -// compile GLSL to SPIR-V at runtime -static std::vector compileGLSLToSpirv(const std::string& source) { - shaderc::Compiler compiler; - shaderc::CompileOptions opts; - opts.SetTargetEnvironment(shaderc_target_env_vulkan, - shaderc_env_version_vulkan_1_1); -#if defined(_RELWITHDEBINFO) - opts.SetGenerateDebugInfo(); - opts.SetOptimizationLevel(shaderc_optimization_level_zero); -#endif - shaderc::SpvCompilationResult result = - compiler.CompileGlslToSpv(source, shaderc_compute_shader, "shader", opts); - if (result.GetCompilationStatus() != shaderc_compilation_status_success) { - throw std::runtime_error(result.GetErrorMessage()); - } - return {result.cbegin(), result.cend()}; -} +struct SlangCompileResult { + std::vector spirv; + uint32_t pushConstantSize = 0; +}; -std::vector compileSlangToSpirv(const char* moduleName, - const char* source, - const char* entry, - const char* profile /* e.g., "spirv_1_5" */) { +SlangCompileResult compileSlangToSpirv(const char* moduleName, + const char* source, + const char* entry, + const char* profile /* e.g., "spirv_1_5" */) { Slang::ComPtr global; createGlobalSession(global.writeRef()); @@ -524,11 +512,32 @@ std::vector compileSlangToSpirv(const char* moduleName, if (SLANG_FAILED(r)) throw std::runtime_error("slang: getEntryPointCode failed"); } + 
uint32_t pushConstantSize = 0; + { + Slang::ComPtr diag; + slang::ProgramLayout* layout = linked->getLayout(0, diag.writeRef()); + if (diag && diag->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)diag->getBufferPointer()); + if (!layout) throw std::runtime_error("slang: failed to obtain program layout"); + + if (auto* globalLayout = layout->getGlobalParamsTypeLayout()) { + size_t size = globalLayout->getSize(slang::ParameterCategory::PushConstantBuffer); + pushConstantSize = std::max(pushConstantSize, static_cast(size)); + } + for (SlangUInt i = 0; i < layout->getEntryPointCount(); ++i) { + if (auto* entry = layout->getEntryPointByIndex(i)) { + if (auto* typeLayout = entry->getTypeLayout()) { + size_t size = typeLayout->getSize(slang::ParameterCategory::PushConstantBuffer); + pushConstantSize = std::max(pushConstantSize, static_cast(size)); + } + } + } + } + size_t n = spirv->getBufferSize(); auto* p = static_cast(spirv->getBufferPointer()); std::vector out((n + 3) / 4); std::memcpy(out.data(), p, n); - return out; + return {std::move(out), pushConstantSize}; } ComputeBindings createBindings(VulkanContext& ctx, const ComputeProgram& prog, @@ -556,11 +565,19 @@ ComputeBindings createBindings(VulkanContext& ctx, const ComputeProgram& prog, } static ComputeProgram createComputeProgram(const std::vector& spirv, - uint32_t roCount, uint32_t rwCount) { + uint32_t roCount, uint32_t rwCount, uint32_t pushConstantSize) { auto& ctx = getVulkanContext(); ComputeProgram prog; prog.numRO = roCount; prog.numRW = rwCount; + prog.pushConstantSize = pushConstantSize; + + if (prog.pushConstantSize) { + auto limits = ctx.physicalDevice.getProperties().limits; + if (prog.pushConstantSize > limits.maxPushConstantsSize) { + throw std::runtime_error("push constant block exceeds device limit"); + } + } vk::ShaderModuleCreateInfo smci({}, spirv.size() * sizeof(uint32_t), spirv.data()); prog.shaderModule = ctx.device.createShaderModule(smci); @@ -572,7 +589,12 @@ static 
ComputeProgram createComputeProgram(const std::vector& spirv, vk::DescriptorSetLayoutCreateInfo dsInfo({}, bindings.size(), bindings.data()); prog.descriptorLayout = ctx.device.createDescriptorSetLayout(dsInfo); - vk::PipelineLayoutCreateInfo plInfo({}, 1, &prog.descriptorLayout); + + vk::PushConstantRange pushRange(vk::ShaderStageFlagBits::eCompute, 0, prog.pushConstantSize); + auto pushPtr = prog.pushConstantSize ? &pushRange : nullptr; + uint32_t pushCount = prog.pushConstantSize ? 1u : 0u; + + vk::PipelineLayoutCreateInfo plInfo({}, 1, &prog.descriptorLayout, pushCount, pushPtr); prog.pipelineLayout = ctx.device.createPipelineLayout(plInfo); vk::PipelineShaderStageCreateInfo stageInfo({}, vk::ShaderStageFlagBits::eCompute, prog.shaderModule, "main"); @@ -582,14 +604,10 @@ static ComputeProgram createComputeProgram(const std::vector& spirv, return prog; } -ComputeProgram createComputeProgramFromGLSL(const std::string& glsl, uint32_t roCount, uint32_t rwCount) { - auto spirv = compileGLSLToSpirv(glsl); - return createComputeProgram(spirv, roCount, rwCount); -} ComputeProgram createComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount) { - auto spirv = compileSlangToSpirv(moduleName.c_str(), source.c_str(), entry.c_str(), "spirv_1_5"); - return createComputeProgram(spirv, roCount, rwCount); + auto result = compileSlangToSpirv(moduleName.c_str(), source.c_str(), entry.c_str(), "spirv_1_5"); + return createComputeProgram(result.spirv, roCount, rwCount, result.pushConstantSize); } void destroyComputeProgram(ComputeProgram& prog) { @@ -605,7 +623,9 @@ void destroyComputeProgram(ComputeProgram& prog) { void runProgram(const ComputeProgram& prog, const std::vector& readonlyBuffers, const std::vector& readwriteBuffers, - uint32_t groupCount) { + uint32_t groupCount, + const void* pushConstants, + size_t pushConstantSize) { auto& ctx = getVulkanContext(); auto set = 
getOrCreateSet(ctx, prog, readonlyBuffers, readwriteBuffers); @@ -615,6 +635,20 @@ void runProgram(const ComputeProgram& prog, cmd.begin(vk::CommandBufferBeginInfo{}); cmd.bindPipeline(vk::PipelineBindPoint::eCompute, prog.pipeline); cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, prog.pipelineLayout, 0, set, {}); + + if (prog.pushConstantSize) { + if (!pushConstants) { + throw std::runtime_error("push constant payload missing"); + } + if (pushConstantSize != prog.pushConstantSize) { + throw std::runtime_error("push constant payload size mismatch"); + } + cmd.pushConstants(prog.pipelineLayout, vk::ShaderStageFlagBits::eCompute, 0, + prog.pushConstantSize, pushConstants); + } else if (pushConstantSize != 0) { + throw std::runtime_error("push constant payload provided but pipeline has none"); + } + cmd.dispatch(groupCount, 1, 1); cmd.end(); diff --git a/TensorFrost/src/Definitions/VulkanBindings.cpp b/TensorFrost/src/Definitions/VulkanBindings.cpp index f2be53f2..c2f63e58 100644 --- a/TensorFrost/src/Definitions/VulkanBindings.cpp +++ b/TensorFrost/src/Definitions/VulkanBindings.cpp @@ -46,20 +46,17 @@ void VulkanDefinitions(py::module_& m) { "Number of read-only storage buffers expected by the program.") .def_property_readonly("readwrite_count", &PyComputeProgram::readwriteCount, "Number of read-write storage buffers expected by the program.") + .def_property_readonly("push_constant_size", &PyComputeProgram::pushConstantSize, + "Size in bytes of the push-constant block expected by this program (0 if unused).") .def("run", &PyComputeProgram::run, - py::arg("readonly_buffers"), py::arg("readwrite_buffers"), py::arg("group_count"), - "Dispatch the compute pipeline with the provided buffers and workgroup count.") + py::arg("readonly_buffers"), + py::arg("readwrite_buffers"), + py::arg("group_count"), + py::arg("push_constants") = py::none(), + "Dispatch the compute pipeline with the provided buffers, workgroup count, and optional push constants.") 
.def("release", &PyComputeProgram::release, "Explicitly destroy the underlying Vulkan pipeline and associated resources."); - m.def("createComputeProgramFromGLSL", - [](const std::string& source, uint32_t roCount, uint32_t rwCount) { - return MakeComputeProgramFromGLSL(source, roCount, rwCount); - }, - py::arg("glsl_source"), py::arg("ro_count"), py::arg("rw_count"), - py::return_value_policy::move, - "Compile a compute shader written in GLSL to SPIR-V and wrap it in a :class:`ComputeProgram`."); - m.def("createComputeProgramFromSlang", [](const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount) { return MakeComputeProgramFromSlang(moduleName, source, entry, roCount, rwCount); diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp index 1dfe20cc..d22bd12d 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.cpp +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -28,6 +29,44 @@ bool isCContiguous(const py::buffer_info& info) { return true; } +struct PushConstantPayload { + const void* data = nullptr; + size_t size = 0; + std::vector storage; +}; + +PushConstantPayload preparePushConstantPayload(const ComputeProgram& program, + const py::object& pushConstants) { + PushConstantPayload payload; + if (program.pushConstantSize == 0) { + if (!pushConstants.is_none()) { + throw std::runtime_error("program does not declare push constants"); + } + return payload; + } + if (pushConstants.is_none()) { + throw std::runtime_error("push constant payload required for this program"); + } + + py::buffer buffer(pushConstants); + auto info = buffer.request(); + if (!isCContiguous(info)) { + throw std::runtime_error("push constant payload must be C-contiguous"); + } + size_t totalBytes = static_cast(info.size) * static_cast(info.itemsize); + if (totalBytes != program.pushConstantSize) { + 
throw std::runtime_error( + "push constant payload size mismatch (got " + std::to_string(totalBytes) + + ", expected " + std::to_string(program.pushConstantSize) + ")"); + } + + payload.storage.resize(totalBytes); + std::memcpy(payload.storage.data(), info.ptr, totalBytes); + payload.data = payload.storage.data(); + payload.size = totalBytes; + return payload; +} + } // namespace PyBuffer::PyBuffer(size_t count, size_t dtypeSize, bool readOnly) @@ -174,7 +213,8 @@ PyComputeProgram& PyComputeProgram::operator=(PyComputeProgram&& other) noexcept void PyComputeProgram::run(const py::iterable& readonlyBuffers, const py::iterable& readwriteBuffers, - uint32_t groupCount) { + uint32_t groupCount, + const py::object& pushConstants) { ensureValid(); std::vector ro; std::vector rw; @@ -183,8 +223,9 @@ void PyComputeProgram::run(const py::iterable& readonlyBuffers, if (ro.size() != program_.numRO || rw.size() != program_.numRW) { throw std::runtime_error("buffer count does not match program layout"); } + auto payload = preparePushConstantPayload(program_, pushConstants); py::gil_scoped_release release; - runProgram(program_, ro, rw, groupCount); + runProgram(program_, ro, rw, groupCount, payload.data, payload.size); } void PyComputeProgram::release() { @@ -199,6 +240,8 @@ uint32_t PyComputeProgram::readonlyCount() const { return program_.numRO; } uint32_t PyComputeProgram::readwriteCount() const { return program_.numRW; } +uint32_t PyComputeProgram::pushConstantSize() const { return program_.pushConstantSize; } + void PyComputeProgram::ensureValid() const { if (!ctx_ || !program_.pipeline) { throw std::runtime_error("ComputeProgram has been released"); @@ -696,12 +739,6 @@ py::tuple PyWindow::vec4ToTuple(const ImVec4& vec) { return py::make_tuple(vec.x, vec.y, vec.z, vec.w); } -PyComputeProgram MakeComputeProgramFromGLSL(const std::string& source, - uint32_t roCount, - uint32_t rwCount) { - return PyComputeProgram(createComputeProgramFromGLSL(source, roCount, rwCount)); -} - 
PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h index 4219205c..f2c2ac2a 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.h +++ b/TensorFrost/src/Definitions/VulkanInterface.h @@ -72,12 +72,14 @@ class PyComputeProgram { void run(const pybind11::iterable& readonlyBuffers, const pybind11::iterable& readwriteBuffers, - uint32_t groupCount); + uint32_t groupCount, + const pybind11::object& pushConstants); void release(); uint32_t readonlyCount() const; uint32_t readwriteCount() const; + uint32_t pushConstantSize() const; private: void ensureValid() const; @@ -205,10 +207,6 @@ class PyWindow { WindowContext window_{}; }; -PyComputeProgram MakeComputeProgramFromGLSL(const std::string& source, - uint32_t roCount, - uint32_t rwCount); - PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py index 6803d450..8a5e9318 100644 --- a/examples/Slang/mandelbrot.py +++ b/examples/Slang/mandelbrot.py @@ -22,10 +22,9 @@ def main() -> None: pixel_capacity = max(1, width * height) pixel_buffer = tf.createBuffer(pixel_capacity, 4, False) - params_buffer = tf.createBuffer(8, 4, True) shader_source = load_shader() - program = tf.createComputeProgramFromSlang("mandelbrot", shader_source, "csMain", ro_count=1, rw_count=1) + program = tf.createComputeProgramFromSlang("mandelbrot", shader_source, "csMain", ro_count=0, rw_count=1) local_size = 64 center = [-0.5, 0.0] @@ -136,8 +135,7 @@ def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: params[6] = float(manual_iterations) params[7] = 1.0 if swap_rb else 0.0 - params_buffer.setData(params) - program.run([params_buffer], [pixel_buffer], group_count) + program.run([], [pixel_buffer], 
group_count, params) win.drawBuffer(pixel_buffer, width, height) diff --git a/examples/Slang/mandelbrot.slang b/examples/Slang/mandelbrot.slang index 4470166f..bcfca9ec 100644 --- a/examples/Slang/mandelbrot.slang +++ b/examples/Slang/mandelbrot.slang @@ -1,5 +1,18 @@ -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] RWStructuredBuffer Pixels : register(u1, space0); +struct MandelbrotParams { + float width; + float height; + float xmin; + float ymin; + float dx; + float dy; + float maxIter; + float swapRB; +}; + +[[vk::push_constant]] +MandelbrotParams gParams; + +[[vk::binding(0,0)]] RWStructuredBuffer Pixels : register(u0, space0); static float3 palette(float t) { return sin(6.3 * float3(0.0, 0.33, 0.67) * t); @@ -9,17 +22,20 @@ static float3 palette(float t) { void csMain(uint3 tid : SV_DispatchThreadID) { uint idx = tid.x; - int W = (int)(Params[0] + 0.5); - int H = (int)(Params[1] + 0.5); + int W = (int)(gParams.width + 0.5); + int H = (int)(gParams.height + 0.5); uint N = (uint)(W * H); if (idx >= N) return; int x = int(idx % (uint)W); int y = int(idx / (uint)W); - float xmin = Params[2], ymin = Params[3], dx = Params[4], dy = Params[5]; - int maxIter = (int)(Params[6] + 0.5); - bool isBGRA = (Params[7] > 0.5); + float xmin = gParams.xmin; + float ymin = gParams.ymin; + float dx = gParams.dx; + float dy = gParams.dy; + int maxIter = (int)(gParams.maxIter + 0.5); + bool isBGRA = (gParams.swapRB > 0.5); float cx = xmin + float(x) * dx; float cy = ymin + float(y) * dy; @@ -47,4 +63,4 @@ void csMain(uint3 tid : SV_DispatchThreadID) uint packed = isBGRA ? (b | (g<<8) | (r<<16) | (a<<24)) : (r | (g<<8) | (b<<16) | (a<<24)); Pixels[idx] = packed; -} \ No newline at end of file +} diff --git a/examples/debug.py b/examples/debug.py index f09492df..5f30aec1 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -1,83 +1,50 @@ import numpy as np import TensorFrost as tf -# GLSL: 1D dispatch (local_size_x=64). 
Pixels are packed with packUnorm4x8. -glsl = r""" -#version 450 -layout(local_size_x = 64) in; - -layout(std430, binding = 0) readonly buffer Params { float p[]; }; // [w,h,xmin,ymin,dx,dy,maxIter,isBGRA] -layout(std430, binding = 1) writeonly buffer Pixels { uint out_u32[]; }; - -vec3 palette(float t) { - // simple smooth palette - return vec3(0.5 + 0.5*cos(6.28318*(vec3(0.0,0.33,0.67)+t))); -} - -void main() { - uint idx1D = gl_GlobalInvocationID.x; - int W = int(p[0] + 0.5), H = int(p[1] + 0.5); - uint N = uint(W*H); - if (idx1D >= N) return; - - int x = int(idx1D % uint(W)); - int y = int(idx1D / uint(W)); - - float xmin = p[2], ymin = p[3], dx = p[4], dy = p[5]; - int maxIter = int(p[6] + 0.5); - bool isBGRA = (p[7] > 0.5); - - float cx = xmin + float(x) * dx; - float cy = ymin + float(y) * dy; - - float zx = 0.0, zy = 0.0; - int i = 0; - for (; i < maxIter; ++i) { - float zx2 = zx*zx - zy*zy + cx; - float zy2 = 2.0*zx*zy + cy; - zx = zx2; zy = zy2; - if (zx*zx + zy*zy > 4.0) break; - } - - float t = (i == maxIter) ? 0.0 : - float(i) - log2(log(length(vec2(zx,zy)))) + 4.0; - t = clamp(t / float(maxIter), 0.0, 1.0); - - vec3 rgb = palette(t); - vec4 c = vec4(rgb, 1.0); - uint packed = isBGRA ? 
packUnorm4x8(c.bgra) : packUnorm4x8(c); - out_u32[idx1D] = packed; +_SLANG = r""" +struct FillParams { + float4 color; +}; + +[[vk::push_constant]] +FillParams gParams; + +[[vk::binding(0,0)]] RWStructuredBuffer Pixels : register(u0, space0); + +[numthreads(64,1,1)] +void csMain(uint3 tid : SV_DispatchThreadID) +{ + uint idx = tid.x; + if (idx >= Pixels.length()) return; + + float4 c = saturate(gParams.color); + uint r = (uint)round(c.r * 255.0); + uint g = (uint)round(c.g * 255.0); + uint b = (uint)round(c.b * 255.0); + uint a = (uint)round(c.a * 255.0); + Pixels[idx] = r | (g << 8) | (b << 16) | (a << 24); } """ -def main(): - W, H = 1024, 768 - local_size = 64 - group_count = max((W * H + local_size - 1) // local_size, 1) - win = tf.createWindow(W, H, "Mandelbrot (compute → buffer → swapchain)") - fmt = int(win.format) - is_bgra = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB - pix = tf.createBuffer(W*H, 4, False) # uint32 pixels - params = tf.createBuffer(8, 4, True) # 8 float32 params +def main() -> None: + width = height = 512 + local_size = 64 + thread_count = width * height + group_count = max((thread_count + local_size - 1) // local_size, 1) - prog = tf.createComputeProgramFromGLSL(glsl, ro_count=1, rw_count=1) + window = tf.createWindow(width, height, "TensorFrost Debug Fill") + pixel_buffer = tf.createBuffer(thread_count, 4, False) + program = tf.createComputeProgramFromSlang("debug_fill", _SLANG, "csMain", ro_count=0, rw_count=1) - # view rectangle with aspect correction - xspan = 3.0 - yspan = xspan * (H / float(W)) - xmin, ymin = -2.0, -yspan * 0.5 - dx, dy = xspan / W, yspan / H - max_iter = 500.0 + color = np.array([0.15, 0.45, 0.95, 1.0], dtype=np.float32) - p = np.array([float(W), float(H), xmin, ymin, dx, dy, max_iter, 1.0 if is_bgra else 0.0], dtype=np.float32) - params.setData(p) + while window.isOpen(): + program.run([], [pixel_buffer], group_count, color) + window.drawBuffer(pixel_buffer, width, height) - while win.isOpen(): - 
prog.run([params], [pix], group_count) - win.drawBuffer(pix, W, H) + window.close() - win.close() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/radix_sort/shaders/bucket_scan.slang b/examples/radix_sort/shaders/bucket_scan.slang index 43d18de9..cf57aa03 100644 --- a/examples/radix_sort/shaders/bucket_scan.slang +++ b/examples/radix_sort/shaders/bucket_scan.slang @@ -1,6 +1,19 @@ -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer GroupPrefix : register(t1, space0); -[[vk::binding(2,0)]] RWStructuredBuffer BucketScan : register(u2, space0); +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer GroupPrefix : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer BucketScan : register(u1, space0); [numthreads(64, 1, 1)] void csBucketScan(uint3 dispatchThreadID : SV_DispatchThreadID) @@ -8,8 +21,8 @@ void csBucketScan(uint3 dispatchThreadID : SV_DispatchThreadID) if (dispatchThreadID.x != 0) return; - uint histogramSize = Params[1]; - uint numGroups = Params[4]; + uint histogramSize = gParams.histogramSize; + uint numGroups = gParams.numGroups; if (histogramSize == 0) return; diff --git a/examples/radix_sort/shaders/histogram.slang b/examples/radix_sort/shaders/histogram.slang index 01e50b47..393b9c1d 100644 --- a/examples/radix_sort/shaders/histogram.slang +++ b/examples/radix_sort/shaders/histogram.slang @@ -7,20 +7,33 @@ uint packedCount(uint histogramSize) return (histogramSize + 3u) >> 2; } -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer KeysIn : register(t1, space0); -[[vk::binding(2,0)]] RWStructuredBuffer PackedHistogram : register(u2, space0); +struct SortParams { + uint elementCount; + uint 
histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer KeysIn : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer PackedHistogram : register(u1, space0); groupshared uint sPacked[MAX_HIST_SIZE / 4]; [numthreads(GROUP_SIZE, 1, 1)] void csHistogram(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) { - uint elementCount = Params[0]; - uint histogramSize = Params[1]; - uint shift = Params[2]; - uint mask = Params[3]; - uint numGroups = Params[4]; + uint elementCount = gParams.elementCount; + uint histogramSize = gParams.histogramSize; + uint shift = gParams.shift; + uint mask = gParams.mask; + uint numGroups = gParams.numGroups; if (histogramSize > MAX_HIST_SIZE) return; diff --git a/examples/radix_sort/shaders/map_from_uint.slang b/examples/radix_sort/shaders/map_from_uint.slang index 643b4509..93491154 100644 --- a/examples/radix_sort/shaders/map_from_uint.slang +++ b/examples/radix_sort/shaders/map_from_uint.slang @@ -5,19 +5,26 @@ static const uint SIGN_BIT = 0x80000000u; static const uint FULL_MASK = 0xFFFFFFFFu; static const uint GROUP_SIZE = TF_GROUP_SIZE; -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer Input : register(t1, space0); -[[vk::binding(2,0)]] RWStructuredBuffer Output : register(u2, space0); +struct MapParams { + uint count; + uint typeCode; +}; + +[[vk::push_constant]] +MapParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer Input : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Output : register(u1, space0); [numthreads(GROUP_SIZE, 1, 1)] void csMapFromUint(uint3 dispatchThreadID : SV_DispatchThreadID) { uint index = dispatchThreadID.x; - uint count = Params[0]; + uint count = gParams.count; if (index >= count) return; - uint typeCode = Params[1]; + uint typeCode = gParams.typeCode; uint value = 
Input[index]; if (typeCode == TYPE_INT) diff --git a/examples/radix_sort/shaders/map_to_uint.slang b/examples/radix_sort/shaders/map_to_uint.slang index 6046235f..9cf4863a 100644 --- a/examples/radix_sort/shaders/map_to_uint.slang +++ b/examples/radix_sort/shaders/map_to_uint.slang @@ -5,19 +5,26 @@ static const uint SIGN_BIT = 0x80000000u; static const uint FULL_MASK = 0xFFFFFFFFu; static const uint GROUP_SIZE = TF_GROUP_SIZE; -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer Input : register(t1, space0); -[[vk::binding(2,0)]] RWStructuredBuffer Output : register(u2, space0); +struct MapParams { + uint count; + uint typeCode; +}; + +[[vk::push_constant]] +MapParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer Input : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Output : register(u1, space0); [numthreads(GROUP_SIZE, 1, 1)] void csMapToUint(uint3 dispatchThreadID : SV_DispatchThreadID) { uint index = dispatchThreadID.x; - uint count = Params[0]; + uint count = gParams.count; if (index >= count) return; - uint typeCode = Params[1]; + uint typeCode = gParams.typeCode; uint value = Input[index]; if (typeCode == TYPE_INT) diff --git a/examples/radix_sort/shaders/prefix_accum.slang b/examples/radix_sort/shaders/prefix_accum.slang index 26e7df66..5cd2ea3f 100644 --- a/examples/radix_sort/shaders/prefix_accum.slang +++ b/examples/radix_sort/shaders/prefix_accum.slang @@ -1,14 +1,27 @@ -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer BlockPrefix : register(t1, space0); -[[vk::binding(2,0)]] RWStructuredBuffer GroupPrefix : register(u2, space0); +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer BlockPrefix : register(t0, space0); 
+[[vk::binding(1,0)]] RWStructuredBuffer GroupPrefix : register(u1, space0); [numthreads(64, 1, 1)] void csPrefixAccumulate(uint3 dispatchThreadID : SV_DispatchThreadID) { - uint histogramSize = Params[1]; - uint numGroups = Params[4]; - uint blockSize = Params[5]; - uint blockCount = Params[6]; + uint histogramSize = gParams.histogramSize; + uint numGroups = gParams.numGroups; + uint blockSize = gParams.blockSize; + uint blockCount = gParams.blockCount; uint totalThreads = blockCount * histogramSize; uint index = dispatchThreadID.x; diff --git a/examples/radix_sort/shaders/prefix_block.slang b/examples/radix_sort/shaders/prefix_block.slang index 5c526574..b4d714ee 100644 --- a/examples/radix_sort/shaders/prefix_block.slang +++ b/examples/radix_sort/shaders/prefix_block.slang @@ -1,12 +1,25 @@ -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer BlockTotals : register(t1, space0); -[[vk::binding(2,0)]] RWStructuredBuffer BlockPrefix : register(u2, space0); +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer BlockTotals : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer BlockPrefix : register(u1, space0); [numthreads(64, 1, 1)] void csPrefixBlocks(uint3 dispatchThreadID : SV_DispatchThreadID) { - uint histogramSize = Params[1]; - uint blockCount = Params[6]; + uint histogramSize = gParams.histogramSize; + uint blockCount = gParams.blockCount; uint bucket = dispatchThreadID.x; if (bucket >= histogramSize) return; diff --git a/examples/radix_sort/shaders/prefix_local.slang b/examples/radix_sort/shaders/prefix_local.slang index 14bed0de..1f87f31a 100644 --- a/examples/radix_sort/shaders/prefix_local.slang +++ b/examples/radix_sort/shaders/prefix_local.slang @@ -1,15 +1,28 @@ -[[vk::binding(0,0)]] 
StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer GroupHistogram : register(t1, space0); -[[vk::binding(2,0)]] RWStructuredBuffer GroupPrefix : register(u2, space0); -[[vk::binding(3,0)]] RWStructuredBuffer BlockTotals : register(u3, space0); +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer GroupHistogram : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer GroupPrefix : register(u1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer BlockTotals : register(u2, space0); [numthreads(64, 1, 1)] void csPrefixLocal(uint3 dispatchThreadID : SV_DispatchThreadID) { - uint histogramSize = Params[1]; - uint numGroups = Params[4]; - uint blockSize = Params[5]; - uint blockCount = Params[6]; + uint histogramSize = gParams.histogramSize; + uint numGroups = gParams.numGroups; + uint blockSize = gParams.blockSize; + uint blockCount = gParams.blockCount; uint totalThreads = blockCount * histogramSize; uint index = dispatchThreadID.x; diff --git a/examples/radix_sort/shaders/scatter.slang b/examples/radix_sort/shaders/scatter.slang index f3b7a11a..dda8d909 100644 --- a/examples/radix_sort/shaders/scatter.slang +++ b/examples/radix_sort/shaders/scatter.slang @@ -2,13 +2,26 @@ static const uint GROUP_SIZE = TF_GROUP_SIZE; static const uint QUARTER_SIZE = GROUP_SIZE / 4; static const uint HISTOGRAM_SIZE = TF_HISTOGRAM_SIZE; -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer KeysIn : register(t1, space0); -[[vk::binding(2,0)]] StructuredBuffer ValuesIn : register(t2, space0); -[[vk::binding(3,0)]] StructuredBuffer GroupPrefix : register(t3, space0); -[[vk::binding(4,0)]] StructuredBuffer BucketScan : register(t4, space0); -[[vk::binding(5,0)]] RWStructuredBuffer KeysOut : 
register(u5, space0); -[[vk::binding(6,0)]] RWStructuredBuffer ValuesOut : register(u6, space0); +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer KeysIn : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer ValuesIn : register(t1, space0); +[[vk::binding(2,0)]] StructuredBuffer GroupPrefix : register(t2, space0); +[[vk::binding(3,0)]] StructuredBuffer BucketScan : register(t3, space0); +[[vk::binding(4,0)]] RWStructuredBuffer KeysOut : register(u4, space0); +[[vk::binding(5,0)]] RWStructuredBuffer ValuesOut : register(u5, space0); groupshared uint tempBits[GROUP_SIZE]; groupshared uint halfCount[HISTOGRAM_SIZE]; @@ -16,11 +29,11 @@ groupshared uint halfCount[HISTOGRAM_SIZE]; [numthreads(GROUP_SIZE, 1, 1)] void csScatter(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) { - uint elementCount = Params[0]; - uint shift = Params[2]; - uint mask = Params[3]; - uint numGroups = Params[4]; - uint hasValues = Params[7]; + uint elementCount = gParams.elementCount; + uint shift = gParams.shift; + uint mask = gParams.mask; + uint numGroups = gParams.numGroups; + uint hasValues = gParams.hasValues; uint group = groupID.x; if (group >= numGroups) diff --git a/examples/radix_sort/shaders/unpack.slang b/examples/radix_sort/shaders/unpack.slang index 03eb29af..49c0ea83 100644 --- a/examples/radix_sort/shaders/unpack.slang +++ b/examples/radix_sort/shaders/unpack.slang @@ -3,15 +3,28 @@ uint packedCount(uint histogramSize) return (histogramSize + 3u) >> 2; } -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); -[[vk::binding(1,0)]] StructuredBuffer PackedHistogram : register(t1, space0); -[[vk::binding(2,0)]] RWStructuredBuffer GroupHistogram : register(u2, space0); +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; 
+ uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer PackedHistogram : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer GroupHistogram : register(u1, space0); [numthreads(64, 1, 1)] void csUnpack(uint3 dispatchThreadID : SV_DispatchThreadID) { - uint histogramSize = Params[1]; - uint numGroups = Params[4]; + uint histogramSize = gParams.histogramSize; + uint numGroups = gParams.numGroups; uint packedCountLocal = packedCount(histogramSize); uint total = numGroups * histogramSize; diff --git a/examples/radix_sort/shaders/validate_sorted.slang b/examples/radix_sort/shaders/validate_sorted.slang index c709bbbb..e5edc419 100644 --- a/examples/radix_sort/shaders/validate_sorted.slang +++ b/examples/radix_sort/shaders/validate_sorted.slang @@ -5,9 +5,16 @@ static const uint SIGN_BIT = 0x80000000u; static const uint FULL_MASK = 0xFFFFFFFFu; static const uint GROUP_SIZE = TF_GROUP_SIZE; -[[vk::binding(0,0)]] StructuredBuffer Params : register(t0, space0); // [0] = count, [1] = typeCode -[[vk::binding(1,0)]] StructuredBuffer Keys : register(t1, space0); // original key bit patterns -[[vk::binding(2,0)]] RWStructuredBuffer Errors : register(u2, space0); // single uint atomic counter +struct ValidateParams { + uint elementCount; + uint typeCode; +}; + +[[vk::push_constant]] +ValidateParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer Keys : register(t0, space0); // original key bit patterns +[[vk::binding(1,0)]] RWStructuredBuffer Errors : register(u1, space0); // single uint atomic counter // Map a key to sortable uint (same transform as map_to_uint.slang) uint mapKey(uint raw, uint typeCode) @@ -28,11 +35,11 @@ uint mapKey(uint raw, uint typeCode) void csValidate(uint3 dispatchThreadID : SV_DispatchThreadID) { uint i = dispatchThreadID.x; - uint n = Params[0]; + uint n = gParams.elementCount; if (n < 2 || i >= n - 1) return; - uint 
typeCode = Params[1]; + uint typeCode = gParams.typeCode; uint a = mapKey(Keys[i], typeCode); uint b = mapKey(Keys[i + 1u], typeCode); @@ -41,4 +48,4 @@ void csValidate(uint3 dispatchThreadID : SV_DispatchThreadID) { InterlockedAdd(Errors[0], 1u); } -} \ No newline at end of file +} diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py index 49aeff10..277feeb3 100644 --- a/examples/radix_sort/sort.py +++ b/examples/radix_sort/sort.py @@ -104,14 +104,14 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "radix_map_to_uint", inject_defines("map_to_uint.slang", with_group=True), "csMapToUint", - ro_count=2, + ro_count=1, rw_count=1, ) self._map_from_uint_program = tf.createComputeProgramFromSlang( "radix_map_from_uint", inject_defines("map_from_uint.slang", with_group=True), "csMapFromUint", - ro_count=2, + ro_count=1, rw_count=1, ) @@ -119,42 +119,42 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "radix_histogram", inject_defines("histogram.slang", with_group=True), "csHistogram", - ro_count=2, + ro_count=1, rw_count=1, ) self._unpack_program = tf.createComputeProgramFromSlang( "radix_unpack", _load_shader_source("unpack.slang"), "csUnpack", - ro_count=2, + ro_count=1, rw_count=1, ) self._prefix_local_program = tf.createComputeProgramFromSlang( "radix_prefix_local", _load_shader_source("prefix_local.slang"), "csPrefixLocal", - ro_count=2, + ro_count=1, rw_count=2, ) self._prefix_blocks_program = tf.createComputeProgramFromSlang( "radix_prefix_blocks", _load_shader_source("prefix_block.slang"), "csPrefixBlocks", - ro_count=2, + ro_count=1, rw_count=1, ) self._prefix_accum_program = tf.createComputeProgramFromSlang( "radix_prefix_accum", _load_shader_source("prefix_accum.slang"), "csPrefixAccumulate", - ro_count=2, + ro_count=1, rw_count=1, ) self._bucket_scan_program = tf.createComputeProgramFromSlang( "radix_bucket_scan", _load_shader_source("bucket_scan.slang"), "csBucketScan", - 
ro_count=2, + ro_count=1, rw_count=1, ) scatter_source = inject_defines("scatter.slang", with_group=True, with_histogram=True) @@ -162,7 +162,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "radix_scatter", scatter_source, "csScatter", - ro_count=5, + ro_count=4, rw_count=2, ) @@ -171,7 +171,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "radix_validate_sorted", inject_defines("validate_sorted.slang", with_group=True), "csValidate", - ro_count=2, + ro_count=1, rw_count=1, ) @@ -225,7 +225,7 @@ def sort( params_array[6] = np.uint32(block_count) params_array[7] = np.uint32(1 if values_array is not None else 0) - map_params = np.zeros(4, dtype=np.uint32) + map_params = np.zeros(2, dtype=np.uint32) map_params[0] = np.uint32(element_count) map_params[1] = _TYPE_CODES[key_kind] @@ -244,12 +244,6 @@ def sort( else: stage_totals = {} - params_buffer = tf.createBuffer(params_array.size, 4, True) - params_buffer.setData(params_array) - - map_buffer = tf.createBuffer(map_params.size, 4, True) - map_buffer.setData(map_params) - key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] key_buffers[0].setData(keys_array) @@ -281,9 +275,10 @@ def sort( total_start = time.perf_counter() if collect_stage_timings else None start = time.perf_counter() if collect_stage_timings else None self._map_to_uint_program.run( - [map_buffer, key_buffers[0]], + [key_buffers[0]], [key_buffers[1]], map_groups, + map_params, ) if collect_stage_timings and start is not None: stage_totals["map_to_uint"] += time.perf_counter() - start @@ -294,67 +289,73 @@ def sort( for pass_index in range(passes): params_array[2] = np.uint32(pass_index * self.bits_per_pass) - params_buffer.setData(params_array) start = time.perf_counter() if collect_stage_timings else None self._histogram_program.run( - [params_buffer, key_in], + [key_in], [packed_hist_buffer], histogram_groups, + params_array, ) if 
collect_stage_timings and start is not None: stage_totals["histogram"] += time.perf_counter() - start start = time.perf_counter() if collect_stage_timings else None self._unpack_program.run( - [params_buffer, packed_hist_buffer], + [packed_hist_buffer], [group_hist_buffer], unpack_groups, + params_array, ) if collect_stage_timings and start is not None: stage_totals["unpack"] += time.perf_counter() - start start = time.perf_counter() if collect_stage_timings else None self._prefix_local_program.run( - [params_buffer, group_hist_buffer], + [group_hist_buffer], [prefix_buffer, block_totals_buffer], prefix_local_groups, + params_array, ) if collect_stage_timings and start is not None: stage_totals["prefix_local"] += time.perf_counter() - start start = time.perf_counter() if collect_stage_timings else None self._prefix_blocks_program.run( - [params_buffer, block_totals_buffer], + [block_totals_buffer], [block_prefix_buffer], prefix_block_groups, + params_array, ) if collect_stage_timings and start is not None: stage_totals["prefix_blocks"] += time.perf_counter() - start start = time.perf_counter() if collect_stage_timings else None self._prefix_accum_program.run( - [params_buffer, block_prefix_buffer], + [block_prefix_buffer], [prefix_buffer], prefix_accum_groups, + params_array, ) if collect_stage_timings and start is not None: stage_totals["prefix_accum"] += time.perf_counter() - start start = time.perf_counter() if collect_stage_timings else None self._bucket_scan_program.run( - [params_buffer, prefix_buffer], + [prefix_buffer], [bucket_scan_buffer], bucket_scan_groups, + params_array, ) if collect_stage_timings and start is not None: stage_totals["bucket_scan"] += time.perf_counter() - start start = time.perf_counter() if collect_stage_timings else None self._scatter_program.run( - [params_buffer, key_in, val_in, prefix_buffer, bucket_scan_buffer], + [key_in, val_in, prefix_buffer, bucket_scan_buffer], [key_out, val_out], scatter_groups, + params_array, ) if 
collect_stage_timings and start is not None: stage_totals["scatter"] += time.perf_counter() - start @@ -365,9 +366,10 @@ def sort( start = time.perf_counter() if collect_stage_timings else None self._map_from_uint_program.run( - [map_buffer, key_in], + [key_in], [key_out], map_groups, + map_params, ) if collect_stage_timings and start is not None: stage_totals["map_from_uint"] += time.perf_counter() - start @@ -382,18 +384,16 @@ def sort( validate_params[0] = np.uint32(element_count) validate_params[1] = map_params[1] # type code - validate_params_buf = tf.createBuffer(validate_params.size, 4, True) - validate_params_buf.setData(validate_params) - error_buf = tf.createBuffer(1, 4, False) error_zero = np.zeros(1, dtype=np.uint32) error_buf.setData(error_zero) # Reuse map_groups; kernel early-outs for i >= n-1 self._validate_program.run( - [validate_params_buf, key_out], + [key_out], [error_buf], map_groups, + validate_params, ) error_count = int(error_buf.getData(np.dtype(np.uint32), 1)[0]) @@ -444,4 +444,4 @@ def radix_sort( sorted_keys, sorted_values = sorter.sort(keys, values, max_bits=max_bits) if values is None: return sorted_keys - return sorted_keys, sorted_values \ No newline at end of file + return sorted_keys, sorted_values diff --git a/tests/vulkan_window_test.py b/tests/vulkan_window_test.py index ebdb8f80..6d1169b5 100644 --- a/tests/vulkan_window_test.py +++ b/tests/vulkan_window_test.py @@ -5,14 +5,15 @@ import TensorFrost as tf -_SIMPLE_GLSL = """#version 450 -layout(local_size_x = 64) in; -layout(set = 0, binding = 0) buffer Pixels { uint data[]; }; - -void main() { - uint idx = gl_GlobalInvocationID.x; - if (idx >= data.length()) return; - data[idx] = 0xff3366ff; +_SIMPLE_SLANG = r""" +[[vk::binding(0,0)]] RWStructuredBuffer Pixels : register(u0, space0); + +[numthreads(64,1,1)] +void csMain(uint3 tid : SV_DispatchThreadID) +{ + uint idx = tid.x; + if (idx >= Pixels.length()) return; + Pixels[idx] = 0xff3366ff; } """ @@ -31,7 +32,9 @@ def 
test_compute_dispatch_and_window_present(self): program = None try: - program = tf.createComputeProgramFromGLSL(_SIMPLE_GLSL, ro_count=0, rw_count=1) + program = tf.createComputeProgramFromSlang( + "window_test_fill", _SIMPLE_SLANG, "csMain", ro_count=0, rw_count=1 + ) except RuntimeError as exc: # pragma: no cover - Vulkan program creation failed self.skipTest(f"Vulkan program creation failed: {exc}") From dd298782609f20dadec5302b28537d1298215950 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Sun, 9 Nov 2025 06:04:14 +0100 Subject: [PATCH 43/44] Remove push const reflection --- Python/TensorFrost/sort.py | 8 ++++ TensorFrost/Backend/include/Backend/Vulkan.h | 3 +- TensorFrost/Backend/src/Vulkan.cpp | 43 ++++--------------- .../src/Definitions/VulkanBindings.cpp | 12 +++++- .../src/Definitions/VulkanInterface.cpp | 6 ++- TensorFrost/src/Definitions/VulkanInterface.h | 3 +- examples/Slang/mandelbrot.py | 9 +++- examples/debug.py | 9 +++- examples/radix_sort/__main__.py | 42 +++++++++++++++--- examples/radix_sort/sort.py | 10 +++++ 10 files changed, 97 insertions(+), 48 deletions(-) diff --git a/Python/TensorFrost/sort.py b/Python/TensorFrost/sort.py index ef5e66e5..4a236bb9 100644 --- a/Python/TensorFrost/sort.py +++ b/Python/TensorFrost/sort.py @@ -92,6 +92,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "csMapToUint", ro_count=1, rw_count=1, + push_constant_size=8, ) self._map_from_uint_program = tf.createComputeProgramFromSlang( "radix_map_from_uint", @@ -99,6 +100,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "csMapFromUint", ro_count=1, rw_count=1, + push_constant_size=8, ) self._histogram_program = tf.createComputeProgramFromSlang( @@ -107,6 +109,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "csHistogram", ro_count=1, rw_count=1, + push_constant_size=32, ) self._unpack_program = 
tf.createComputeProgramFromSlang( "radix_unpack", @@ -114,6 +117,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "csUnpack", ro_count=1, rw_count=1, + push_constant_size=32, ) self._prefix_local_program = tf.createComputeProgramFromSlang( "radix_prefix_local", @@ -121,6 +125,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "csPrefixLocal", ro_count=1, rw_count=2, + push_constant_size=32, ) self._prefix_blocks_program = tf.createComputeProgramFromSlang( "radix_prefix_blocks", @@ -128,6 +133,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "csPrefixBlocks", ro_count=1, rw_count=1, + push_constant_size=32, ) self._prefix_accum_program = tf.createComputeProgramFromSlang( "radix_prefix_accum", @@ -135,6 +141,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "csPrefixAccumulate", ro_count=1, rw_count=1, + push_constant_size=32, ) self._bucket_scan_program = tf.createComputeProgramFromSlang( "radix_bucket_scan", @@ -142,6 +149,7 @@ def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: "csBucketScan", ro_count=1, rw_count=1, + push_constant_size=32, ) scatter_source = f"#define TF_HISTOGRAM_SIZE {self.histogram_size}u\n" + _load_shader_source("scatter.slang") self._scatter_program = tf.createComputeProgramFromSlang( diff --git a/TensorFrost/Backend/include/Backend/Vulkan.h b/TensorFrost/Backend/include/Backend/Vulkan.h index 071a78f1..d61041d2 100644 --- a/TensorFrost/Backend/include/Backend/Vulkan.h +++ b/TensorFrost/Backend/include/Backend/Vulkan.h @@ -90,7 +90,8 @@ void setBufferData(Buffer& buf, const void* src, size_t bytes, size_t offset = 0 void getBufferData(const Buffer& buf, void* dst, size_t bytes, size_t offset = 0); ComputeProgram createComputeProgramFromSlang(const std::string& moduleName, - const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount); + const 
std::string& source, const std::string& entry, + uint32_t roCount, uint32_t rwCount, uint32_t pushConstantSize = 0); void destroyComputeProgram(ComputeProgram& prog); void runProgram(const ComputeProgram& prog, diff --git a/TensorFrost/Backend/src/Vulkan.cpp b/TensorFrost/Backend/src/Vulkan.cpp index a36534df..11794e1a 100644 --- a/TensorFrost/Backend/src/Vulkan.cpp +++ b/TensorFrost/Backend/src/Vulkan.cpp @@ -445,15 +445,10 @@ VulkanContext::~VulkanContext() { instance.destroy(); } -struct SlangCompileResult { - std::vector spirv; - uint32_t pushConstantSize = 0; -}; - -SlangCompileResult compileSlangToSpirv(const char* moduleName, - const char* source, - const char* entry, - const char* profile /* e.g., "spirv_1_5" */) { +std::vector compileSlangToSpirv(const char* moduleName, + const char* source, + const char* entry, + const char* profile /* e.g., "spirv_1_5" */) { Slang::ComPtr global; createGlobalSession(global.writeRef()); @@ -512,32 +507,11 @@ SlangCompileResult compileSlangToSpirv(const char* moduleName, if (SLANG_FAILED(r)) throw std::runtime_error("slang: getEntryPointCode failed"); } - uint32_t pushConstantSize = 0; - { - Slang::ComPtr diag; - slang::ProgramLayout* layout = linked->getLayout(0, diag.writeRef()); - if (diag && diag->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)diag->getBufferPointer()); - if (!layout) throw std::runtime_error("slang: failed to obtain program layout"); - - if (auto* globalLayout = layout->getGlobalParamsTypeLayout()) { - size_t size = globalLayout->getSize(slang::ParameterCategory::PushConstantBuffer); - pushConstantSize = std::max(pushConstantSize, static_cast(size)); - } - for (SlangUInt i = 0; i < layout->getEntryPointCount(); ++i) { - if (auto* entry = layout->getEntryPointByIndex(i)) { - if (auto* typeLayout = entry->getTypeLayout()) { - size_t size = typeLayout->getSize(slang::ParameterCategory::PushConstantBuffer); - pushConstantSize = std::max(pushConstantSize, static_cast(size)); - } - } - } - } - 
size_t n = spirv->getBufferSize(); auto* p = static_cast(spirv->getBufferPointer()); std::vector out((n + 3) / 4); std::memcpy(out.data(), p, n); - return {std::move(out), pushConstantSize}; + return out; } ComputeBindings createBindings(VulkanContext& ctx, const ComputeProgram& prog, @@ -605,9 +579,10 @@ static ComputeProgram createComputeProgram(const std::vector& spirv, } ComputeProgram createComputeProgramFromSlang(const std::string& moduleName, - const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount) { - auto result = compileSlangToSpirv(moduleName.c_str(), source.c_str(), entry.c_str(), "spirv_1_5"); - return createComputeProgram(result.spirv, roCount, rwCount, result.pushConstantSize); + const std::string& source, const std::string& entry, + uint32_t roCount, uint32_t rwCount, uint32_t pushConstantSize) { + auto spirv = compileSlangToSpirv(moduleName.c_str(), source.c_str(), entry.c_str(), "spirv_1_5"); + return createComputeProgram(spirv, roCount, rwCount, pushConstantSize); } void destroyComputeProgram(ComputeProgram& prog) { diff --git a/TensorFrost/src/Definitions/VulkanBindings.cpp b/TensorFrost/src/Definitions/VulkanBindings.cpp index c2f63e58..cb30e428 100644 --- a/TensorFrost/src/Definitions/VulkanBindings.cpp +++ b/TensorFrost/src/Definitions/VulkanBindings.cpp @@ -2,6 +2,7 @@ #include "VulkanInterface.h" #include +#include #include #include @@ -58,11 +59,18 @@ void VulkanDefinitions(py::module_& m) { "Explicitly destroy the underlying Vulkan pipeline and associated resources."); m.def("createComputeProgramFromSlang", - [](const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, uint32_t rwCount) { - return MakeComputeProgramFromSlang(moduleName, source, entry, roCount, rwCount); + [](const std::string& moduleName, + const std::string& source, + const std::string& entry, + uint32_t roCount, + uint32_t rwCount, + uint32_t pushConstantSize) { + return 
MakeComputeProgramFromSlang( + moduleName, source, entry, roCount, rwCount, pushConstantSize); }, py::arg("module_name"), py::arg("source"), py::arg("entry"), py::arg("ro_count"), py::arg("rw_count"), + py::arg("push_constant_size") = 0, py::return_value_policy::move, "Compile a Slang module to SPIR-V and wrap it in a :class:`ComputeProgram`."); diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp index d22bd12d..9cab1364 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.cpp +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -743,8 +743,10 @@ PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, - uint32_t rwCount) { - return PyComputeProgram(createComputeProgramFromSlang(moduleName, source, entry, roCount, rwCount)); + uint32_t rwCount, + uint32_t pushConstantSize) { + return PyComputeProgram(createComputeProgramFromSlang( + moduleName, source, entry, roCount, rwCount, pushConstantSize)); } } // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h index f2c2ac2a..8b462151 100644 --- a/TensorFrost/src/Definitions/VulkanInterface.h +++ b/TensorFrost/src/Definitions/VulkanInterface.h @@ -211,6 +211,7 @@ PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, const std::string& source, const std::string& entry, uint32_t roCount, - uint32_t rwCount); + uint32_t rwCount, + uint32_t pushConstantSize = 0); } // namespace TensorFrost diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py index 8a5e9318..f8021ca5 100644 --- a/examples/Slang/mandelbrot.py +++ b/examples/Slang/mandelbrot.py @@ -24,7 +24,14 @@ def main() -> None: pixel_buffer = tf.createBuffer(pixel_capacity, 4, False) shader_source = load_shader() - program = tf.createComputeProgramFromSlang("mandelbrot", shader_source, "csMain", 
ro_count=0, rw_count=1) + program = tf.createComputeProgramFromSlang( + "mandelbrot", + shader_source, + "csMain", + ro_count=0, + rw_count=1, + push_constant_size=32, + ) local_size = 64 center = [-0.5, 0.0] diff --git a/examples/debug.py b/examples/debug.py index 5f30aec1..99213136 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -35,7 +35,14 @@ def main() -> None: window = tf.createWindow(width, height, "TensorFrost Debug Fill") pixel_buffer = tf.createBuffer(thread_count, 4, False) - program = tf.createComputeProgramFromSlang("debug_fill", _SLANG, "csMain", ro_count=0, rw_count=1) + program = tf.createComputeProgramFromSlang( + "debug_fill", + _SLANG, + "csMain", + ro_count=0, + rw_count=1, + push_constant_size=16, + ) color = np.array([0.15, 0.45, 0.95, 1.0], dtype=np.float32) diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py index 94260244..60c321ec 100644 --- a/examples/radix_sort/__main__.py +++ b/examples/radix_sort/__main__.py @@ -131,14 +131,15 @@ def main() -> None: frame_time_total = sum(frame_times) fps = (len(frame_times) / frame_time_total) if frame_times and frame_time_total > 0.0 else 0.0 - # Perform sort; on the first run also validate on GPU and avoid full array readback. + # Perform sort; on the first run also validate on GPU and read back buffers for NumPy comparison. do_validate = not validated + return_arrays = do_validate _keys_out, _vals_out = sorter.sort( keys, values, collect_stage_timings=True, validate=do_validate, - return_arrays=False, + return_arrays=return_arrays, ) stage_timings = sorter.last_stage_timings or {} # Separate total_pass (overall) from per-stage summed time @@ -157,10 +158,39 @@ def main() -> None: if do_validate: errors = int(getattr(sorter, "last_validation_errors", 0) or 0) - validation_ok = (errors == 0) - validation_message = ( - "GPU validation passed (sorted)." 
if validation_ok else f"GPU validation failed: {errors} out-of-order pairs" - ) + gpu_validation_ok = (errors == 0) + + cpu_validation_ok = True + cpu_messages = [] + if return_arrays: + if count: + reference_indices = np.argsort(keys, kind="stable") + reference_keys = keys[reference_indices] + reference_values = values[reference_indices] if values is not None else None + else: + reference_keys = np.copy(keys) + reference_values = np.copy(values) if values is not None else None + + if not np.array_equal(_keys_out, reference_keys): + cpu_validation_ok = False + cpu_messages.append("key mismatch") + + if values is not None and _vals_out is not None and not np.array_equal(_vals_out, reference_values): + cpu_validation_ok = False + cpu_messages.append("value mismatch") + + validation_ok = gpu_validation_ok and cpu_validation_ok + if validation_ok: + validation_message = "GPU + NumPy validation passed." + else: + failure_reasons = [] + if not gpu_validation_ok: + failure_reasons.append(f"GPU reported {errors} out-of-order pairs") + if not cpu_validation_ok: + reason = ", ".join(cpu_messages) if cpu_messages else "NumPy comparison failed" + failure_reasons.append(reason) + validation_message = "Validation failed: " + "; ".join(failure_reasons) + validated = True # No large array readback performed when return_arrays=False. 
diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py index 277feeb3..08ed3d85 100644 --- a/examples/radix_sort/sort.py +++ b/examples/radix_sort/sort.py @@ -106,6 +106,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csMapToUint", ro_count=1, rw_count=1, + push_constant_size=8, ) self._map_from_uint_program = tf.createComputeProgramFromSlang( "radix_map_from_uint", @@ -113,6 +114,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csMapFromUint", ro_count=1, rw_count=1, + push_constant_size=8, ) self._histogram_program = tf.createComputeProgramFromSlang( @@ -121,6 +123,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csHistogram", ro_count=1, rw_count=1, + push_constant_size=32, ) self._unpack_program = tf.createComputeProgramFromSlang( "radix_unpack", @@ -128,6 +131,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csUnpack", ro_count=1, rw_count=1, + push_constant_size=32, ) self._prefix_local_program = tf.createComputeProgramFromSlang( "radix_prefix_local", @@ -135,6 +139,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csPrefixLocal", ro_count=1, rw_count=2, + push_constant_size=32, ) self._prefix_blocks_program = tf.createComputeProgramFromSlang( "radix_prefix_blocks", @@ -142,6 +147,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csPrefixBlocks", ro_count=1, rw_count=1, + push_constant_size=32, ) self._prefix_accum_program = tf.createComputeProgramFromSlang( "radix_prefix_accum", @@ -149,6 +155,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csPrefixAccumulate", ro_count=1, rw_count=1, + push_constant_size=32, ) self._bucket_scan_program = tf.createComputeProgramFromSlang( "radix_bucket_scan", @@ -156,6 +163,7 @@ def inject_defines(filename: str, *, with_group: bool = False, 
with_histogram: b "csBucketScan", ro_count=1, rw_count=1, + push_constant_size=32, ) scatter_source = inject_defines("scatter.slang", with_group=True, with_histogram=True) self._scatter_program = tf.createComputeProgramFromSlang( @@ -164,6 +172,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csScatter", ro_count=4, rw_count=2, + push_constant_size=32, ) # Validation program: checks if adjacent key pairs are sorted (after mapping back to original type) @@ -173,6 +182,7 @@ def inject_defines(filename: str, *, with_group: bool = False, with_histogram: b "csValidate", ro_count=1, rw_count=1, + push_constant_size=8, ) self._dummy_values_buffer = tf.createBuffer(1, 4, False) From 2d98257220422933da5eaf4d595a1fb62d8a5481 Mon Sep 17 00:00:00 2001 From: Mykhailo Moroz <47035925+MichaelMoroz@users.noreply.github.com> Date: Mon, 22 Dec 2025 16:52:44 +0100 Subject: [PATCH 44/44] counting sort --- .idea/copilot.data.migration.agent.xml | 2 +- AGENTS.md | 5 +- examples/counting_sort/__main__.py | 273 +++++++++++++ .../shaders/block_prefix_stage2.slang | 24 ++ .../counting_sort/shaders/block_sum.slang | 37 ++ .../shaders/histogram_rank.slang | 29 ++ examples/counting_sort/shaders/scatter.slang | 47 +++ .../shaders/validate_sorted.slang | 27 ++ examples/counting_sort/sort.py | 362 ++++++++++++++++++ 9 files changed, 803 insertions(+), 3 deletions(-) create mode 100644 examples/counting_sort/__main__.py create mode 100644 examples/counting_sort/shaders/block_prefix_stage2.slang create mode 100644 examples/counting_sort/shaders/block_sum.slang create mode 100644 examples/counting_sort/shaders/histogram_rank.slang create mode 100644 examples/counting_sort/shaders/scatter.slang create mode 100644 examples/counting_sort/shaders/validate_sorted.slang create mode 100644 examples/counting_sort/sort.py diff --git a/.idea/copilot.data.migration.agent.xml b/.idea/copilot.data.migration.agent.xml index 89c07510..4ea72a91 100644 --- 
a/.idea/copilot.data.migration.agent.xml +++ b/.idea/copilot.data.migration.agent.xml @@ -1,6 +1,6 @@ - \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index ce9710e9..f15737e5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,5 +5,6 @@ Follow these expectations whenever you work in this repository: 1. **Full rebuild & virtual environment** — Run `setup_python_env.cmd` from the repo root. It configures the Python virtual environment and performs a clean rebuild so you start from a consistent state. 2. **Partial rebuilds** — Use CMake for incremental builds. Invoke the appropriate CMake build command (for example, `cmake --build --target `) to rebuild only what you need. 3. **C++ changes** — Any edits under `TensorFrost/` or other C++ sources require a rebuild before the changes take effect. -4. **API validation** — After modifying functionality, run the relevant tests in the `tests/` folder to confirm the Python API still behaves as expected. -5. **Scenario validation** — Run the sample programs in the `examples/` folder to make sure the updated stack handles more complex end-to-end flows. +4. **Python script changes** — After editing a Python script, run it to make sure it works correctly. No recompilation is needed. +5. **API validation** — After modifying functionality, run the relevant tests in the `tests/` folder to confirm the Python API still behaves as expected. +6. **Scenario validation** — Run the sample programs in the `examples/` folder to make sure the updated stack handles more complex end-to-end flows. 
diff --git a/examples/counting_sort/__main__.py b/examples/counting_sort/__main__.py new file mode 100644 index 00000000..2a77bd09 --- /dev/null +++ b/examples/counting_sort/__main__.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import argparse +import math +import time +from collections import deque +from pathlib import Path + +import numpy as np +import TensorFrost as tf + +try: + from .sort import CountingSort +except ImportError: + import sys + + _CURRENT_DIR = Path(__file__).resolve().parent + if str(_CURRENT_DIR) not in sys.path: + sys.path.insert(0, str(_CURRENT_DIR)) + + from sort import CountingSort + + +_STAGE_NAMES = ( + "histogram", + "block_sum", + "block_prefix_stage1", + "block_prefix_stage2", + "scatter", +) + + +def _select_backend() -> None: + if hasattr(tf, "initialize"): + backend = getattr(tf, "vulkan", None) + if backend is None: + raise RuntimeError("TensorFrost Vulkan backend is unavailable on this build") + tf.initialize(backend) + + +def _format_rate(value: float, unit: str) -> str: + if value <= 0.0 or not math.isfinite(value): + return f"0 {unit}" + + prefixes = ("", "K", "M", "G", "T", "P") + magnitude = 0 + while value >= 1000.0 and magnitude < len(prefixes) - 1: + value /= 1000.0 + magnitude += 1 + + return f"{value:6.2f} {prefixes[magnitude]}{unit}" + + +def _format_stage_name(name: str) -> str: + return name.replace("_", " ").title() + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Interactive counting sort demo running continuously on the Vulkan backend.", + ) + parser.add_argument("--size", type=int, default=1 << 22, help="Number of keys to sort") + parser.add_argument( + "--max-value", + type=int, + default=1 << 20, + help="Exclusive upper bound for the generated key values", + ) + parser.add_argument("--window-width", type=int, default=960, help="Window width in pixels") + parser.add_argument("--window-height", type=int, default=540, help="Window height in pixels") + 
parser.add_argument("--font-scale", type=float, default=1.6, help="Global ImGui font scale") + parser.add_argument("--history", type=int, default=240, help="Number of samples stored for metrics history") + parser.add_argument("--seed", type=int, default=1337, help="Seed for the random data generator") + args = parser.parse_args() + + _select_backend() + + count = max(0, int(args.size)) + max_value = max(1, int(args.max_value)) + window_width = max(320, int(args.window_width)) + window_height = max(240, int(args.window_height)) + history_length = max(1, int(args.history)) + font_scale = max(0.1, float(args.font_scale)) if args.font_scale > 0 else 0.0 + + rng = np.random.default_rng(int(args.seed)) + keys = ( + rng.integers(0, max_value, size=count, dtype=np.uint32) + if count + else np.empty(0, dtype=np.uint32) + ) + values = ( + rng.integers(0, 1 << 31, size=count, dtype=np.uint32) + if count + else None + ) + + frame_times = deque(maxlen=history_length) + sort_times = deque(maxlen=history_length) + stage_totals_overall = {name: 0.0 for name in _STAGE_NAMES} + + validated = False + validation_ok = False + validation_message = "" + + sort_count = 0 + total_kernel_time = 0.0 + + sorter = CountingSort(max_value=max_value) + + window_title = f"TensorFrost Counting Sort ({count:,} elements)" + window = tf.createWindow(window_width, window_height, window_title) + + if font_scale > 0.0 and window is not None: + window.imgui_scale_all_sizes(font_scale) + window.imgui_set_font_global_scale(font_scale) + + start_time = time.perf_counter() + last_frame_time = start_time + + while window is not None and window.isOpen(): + now = time.perf_counter() + dt = now - last_frame_time + last_frame_time = now + + if dt > 0.0: + frame_times.append(dt) + frame_time_total = sum(frame_times) + fps = (len(frame_times) / frame_time_total) if frame_times and frame_time_total > 0.0 else 0.0 + + do_validate = not validated + return_arrays = do_validate + _keys_out, _vals_out = sorter.sort( + 
keys, + values, + collect_stage_timings=True, + validate=do_validate, + return_arrays=return_arrays, + ) + + stage_timings = sorter.last_stage_timings or {} + total_pass_time = float(stage_timings.get("total_pass", 0.0)) + stage_sum_time = float(sum(v for k, v in stage_timings.items() if k != "total_pass")) + kernel_time = total_pass_time if total_pass_time > 0.0 else stage_sum_time + + if sort_times.maxlen == len(sort_times): + sort_times.popleft() + sort_times.append(kernel_time) + + sort_count += 1 + total_kernel_time += kernel_time + for name in _STAGE_NAMES: + stage_totals_overall[name] += stage_timings.get(name, 0.0) + + if do_validate: + errors = int(getattr(sorter, "last_validation_errors", 0) or 0) + gpu_validation_ok = (errors == 0) + + cpu_validation_ok = True + cpu_messages = [] + if return_arrays and count: + reference_keys = np.sort(keys, kind="stable") + if not np.array_equal(_keys_out, reference_keys): + cpu_validation_ok = False + cpu_messages.append("key mismatch") + + validation_ok = gpu_validation_ok and cpu_validation_ok + if validation_ok: + validation_message = "GPU + NumPy validation passed." 
+ else: + failure_reasons = [] + if not gpu_validation_ok: + failure_reasons.append(f"GPU reported {errors} out-of-order pairs") + if not cpu_validation_ok: + reason = ", ".join(cpu_messages) if cpu_messages else "NumPy comparison failed" + failure_reasons.append(reason) + validation_message = "Validation failed: " + "; ".join(failure_reasons) + + validated = True + + window_avg_sort = (sum(sort_times) / len(sort_times)) if sort_times else 0.0 + last_sort_ms = kernel_time * 1000.0 + avg_sort_ms = window_avg_sort * 1000.0 + + last_sort_rate = (1.0 / kernel_time) if kernel_time > 0.0 else 0.0 + window_sort_rate = (1.0 / window_avg_sort) if window_avg_sort > 0.0 else 0.0 + overall_avg_sort = (total_kernel_time / sort_count) if sort_count else 0.0 + overall_sort_rate = (1.0 / overall_avg_sort) if overall_avg_sort > 0.0 else 0.0 + + last_elements_per_sec = (count / kernel_time) if count and kernel_time > 0.0 else 0.0 + window_elements_per_sec = (count / window_avg_sort) if count and window_avg_sort > 0.0 else 0.0 + + history_data = np.array(sort_times, dtype=np.float32) * 1000.0 + + visible, _ = window.imgui_begin("Counting Sort Performance", open=None, flags=0) + if visible: + window.imgui_text(f"Elements: {count:,}") + window.imgui_text(f"Max value: {max_value:,}") + window.imgui_separator() + window.imgui_text(f"Frame time: {dt * 1000.0:7.3f} ms") + window.imgui_text(f"Average FPS: {fps:6.1f}") + window.imgui_separator() + + if sort_count: + window.imgui_text(f"Last kernel time: {last_sort_ms:7.3f} ms") + if total_pass_time > 0.0: + window.imgui_text(f"Total pass time (last): {total_pass_time * 1000.0:7.3f} ms") + if window_avg_sort > 0.0: + window.imgui_text(f"Window avg kernel: {avg_sort_ms:7.3f} ms") + if overall_avg_sort > 0.0: + window.imgui_text(f"Overall avg kernel: {overall_avg_sort * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text(f"Sorts/sec (last): {last_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (window): {window_sort_rate:8.2f}") 
+ window.imgui_text(f"Sorts/sec (overall): {overall_sort_rate:8.2f}") + + if count: + window.imgui_spacing() + window.imgui_text( + f"Elements/sec (last): {_format_rate(last_elements_per_sec, ' elements/s')}" + ) + window.imgui_text( + f"Elements/sec (window): {_format_rate(window_elements_per_sec, ' elements/s')}" + ) + + window.imgui_spacing() + window.imgui_text(f"Total sorts: {sort_count:,}") + else: + if count: + window.imgui_text("Waiting for first GPU sort...") + else: + window.imgui_text("No elements to sort.") + + if validated: + color = (0.20, 0.80, 0.40, 1.0) if validation_ok else (0.92, 0.35, 0.32, 1.0) + window.imgui_text_colored(color, validation_message) + elif count: + window.imgui_text("Validating against CPU reference...") + + if history_data.size: + max_plot = float(np.max(history_data)) if history_data.size else 1.0 + max_plot = max(1.0, max_plot * 1.1) + window.imgui_plot_lines( + "Kernel time history (ms)", + history_data, + values_offset=0, + overlay_text=f"{history_data.size} sample history", + scale_min=0.0, + scale_max=max_plot, + graph_size=(0.0, 140.0), + stride=4, + ) + + if stage_timings: + window.imgui_separator() + window.imgui_text("Kernel stages (last run):") + for name, duration in sorted(stage_timings.items(), key=lambda item: item[1], reverse=True): + if name == "total_pass": + continue + window.imgui_text(f"{_format_stage_name(name)}: {duration * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text("Kernel stages (overall avg):") + for name in sorted(stage_totals_overall.keys()): + avg_duration = (stage_totals_overall[name] / sort_count) if sort_count else 0.0 + window.imgui_text(f"{_format_stage_name(name)}: {avg_duration * 1000.0:7.3f} ms") + + window.imgui_end() + window.present() + + +if __name__ == "__main__": + main() diff --git a/examples/counting_sort/shaders/block_prefix_stage2.slang b/examples/counting_sort/shaders/block_prefix_stage2.slang new file mode 100644 index 00000000..ca4d7275 --- /dev/null +++ 
b/examples/counting_sort/shaders/block_prefix_stage2.slang @@ -0,0 +1,24 @@ +struct BlockPrefixStage2Params { + uint groupCount; +}; + +[[vk::push_constant]] +BlockPrefixStage2Params gParams; + +[[vk::binding(0,0)]] StructuredBuffer BlockGroupTotals : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer BlockGroupPrefix : register(u1, space0); + +[numthreads(64, 1, 1)] +void csBlockPrefixStage2(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + if (dispatchThreadID.x != 0) + return; + + uint groupCount = gParams.groupCount; + uint prefix = 0; + for (uint group = 0; group < groupCount; ++group) + { + BlockGroupPrefix[group] = prefix; + prefix += BlockGroupTotals[group]; + } +} diff --git a/examples/counting_sort/shaders/block_sum.slang b/examples/counting_sort/shaders/block_sum.slang new file mode 100644 index 00000000..c3312fc2 --- /dev/null +++ b/examples/counting_sort/shaders/block_sum.slang @@ -0,0 +1,37 @@ +struct SegmentScanParams { + uint segmentCount; + uint segmentSpan; + uint inputLimit; +}; + +[[vk::push_constant]] +SegmentScanParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer InputData : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer PrefixOutput : register(u1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer SegmentTotals : register(u2, space0); + +[numthreads(64, 1, 1)] +void csSegmentScan(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint segment = dispatchThreadID.x; + if (segment >= gParams.segmentCount) + return; + + uint start = segment * gParams.segmentSpan; + if (start >= gParams.inputLimit) + { + SegmentTotals[segment] = 0; + return; + } + + uint endIndex = min(start + gParams.segmentSpan, gParams.inputLimit); + uint prefix = 0; + for (uint idx = start; idx < endIndex; ++idx) + { + PrefixOutput[idx] = prefix; + prefix += InputData[idx]; + } + + SegmentTotals[segment] = prefix; +} diff --git a/examples/counting_sort/shaders/histogram_rank.slang b/examples/counting_sort/shaders/histogram_rank.slang new file 
mode 100644 index 00000000..a20f11cf --- /dev/null +++ b/examples/counting_sort/shaders/histogram_rank.slang @@ -0,0 +1,29 @@ +static const uint GROUP_SIZE = CS_GROUP_SIZE; + +struct HistogramParams { + uint elementCount; + uint valueCount; +}; + +[[vk::push_constant]] +HistogramParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer KeysIn : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Histogram : register(u1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer LocalRank : register(u2, space0); + +[numthreads(GROUP_SIZE, 1, 1)] +void csHistogramAndRank(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + if (index >= gParams.elementCount) + return; + + uint value = KeysIn[index]; + if (value >= gParams.valueCount) + return; + + uint rank = 0; + InterlockedAdd(Histogram[value], 1u, rank); + LocalRank[index] = rank; +} diff --git a/examples/counting_sort/shaders/scatter.slang b/examples/counting_sort/shaders/scatter.slang new file mode 100644 index 00000000..ec1bc039 --- /dev/null +++ b/examples/counting_sort/shaders/scatter.slang @@ -0,0 +1,47 @@ +static const uint GROUP_SIZE = CS_GROUP_SIZE; + +struct ScatterParams { + uint elementCount; + uint valueCount; + uint blockSpan; + uint blocksPerGroup; + uint hasValues; +}; + +[[vk::push_constant]] +ScatterParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer KeysIn : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer ValuesIn : register(t1, space0); +[[vk::binding(2,0)]] StructuredBuffer LocalRank : register(t2, space0); +[[vk::binding(3,0)]] StructuredBuffer LocalPrefix : register(t3, space0); +[[vk::binding(4,0)]] StructuredBuffer BlockPrefixStage1 : register(t4, space0); +[[vk::binding(5,0)]] StructuredBuffer BlockGroupPrefix : register(t5, space0); +[[vk::binding(6,0)]] RWStructuredBuffer KeysOut : register(u6, space0); +[[vk::binding(7,0)]] RWStructuredBuffer ValuesOut : register(u7, space0); + +[numthreads(GROUP_SIZE, 1, 1)] +void csScatter(uint3 
dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + uint elementCount = gParams.elementCount; + if (index >= elementCount) + return; + + uint value = KeysIn[index]; + if (value >= gParams.valueCount) + return; + + uint rank = LocalRank[index]; + uint block = value / gParams.blockSpan; + uint group = block / max(gParams.blocksPerGroup, 1u); + uint localOffset = LocalPrefix[value]; + uint blockOffset = BlockPrefixStage1[block] + BlockGroupPrefix[group]; + uint destination = blockOffset + localOffset + rank; + + KeysOut[destination] = value; + if (gParams.hasValues != 0) + { + ValuesOut[destination] = ValuesIn[index]; + } +} diff --git a/examples/counting_sort/shaders/validate_sorted.slang b/examples/counting_sort/shaders/validate_sorted.slang new file mode 100644 index 00000000..524e5bf7 --- /dev/null +++ b/examples/counting_sort/shaders/validate_sorted.slang @@ -0,0 +1,27 @@ +static const uint GROUP_SIZE = CS_GROUP_SIZE; + +struct ValidateParams { + uint elementCount; +}; + +[[vk::push_constant]] +ValidateParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer Keys : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Errors : register(u1, space0); + +[numthreads(GROUP_SIZE, 1, 1)] +void csValidate(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + uint n = gParams.elementCount; + if (n < 2 || index >= n - 1) + return; + + uint a = Keys[index]; + uint b = Keys[index + 1u]; + if (a > b) + { + InterlockedAdd(Errors[0], 1u); + } +} diff --git a/examples/counting_sort/sort.py b/examples/counting_sort/sort.py new file mode 100644 index 00000000..1b8c8b0c --- /dev/null +++ b/examples/counting_sort/sort.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional, Tuple + +import numpy as np + +import TensorFrost as tf + +__all__ = ["CountingSort", "counting_sort"] + + +def 
_dispatch_groups(work_items: int, threads_per_group: int) -> int: + if work_items <= 0: + return 0 + return (work_items + threads_per_group - 1) // threads_per_group + + +def _prepare_keys(keys: np.ndarray) -> np.ndarray: + array = np.asarray(keys) + if array.ndim != 1: + raise ValueError("counting_sort expects a 1D array of keys") + if array.dtype != np.uint32: + array = array.astype(np.uint32, copy=False) + return array + + +def _prepare_values(values: np.ndarray) -> Tuple[np.ndarray, np.dtype]: + array = np.asarray(values) + if array.ndim != 1: + raise ValueError("counting_sort expects a 1D array of values when provided") + dtype = array.dtype + if dtype not in (np.uint32, np.int32, np.float32): + raise TypeError("values must have dtype uint32, int32, or float32") + return array, dtype + + +_SHADER_DIR = Path(__file__).resolve().parent / "shaders" + + +def _load_shader_source(filename: str) -> str: + return (_SHADER_DIR / filename).read_text(encoding="utf-8") + + +@dataclass(frozen=True) +class _SorterKey: + max_value: int + block_span: int + group_size: int + blocks_per_group: int + + +class CountingSort: + """GPU counting sort built on top of TensorFrost compute and Slang shaders.""" + + def __init__( + self, + *, + max_value: int, + block_span: int = 256, + group_size: int = 128, + blocks_per_group: int = 256, + ) -> None: + if max_value <= 0: + raise ValueError("max_value must be positive") + if block_span <= 0: + raise ValueError("block_span must be positive") + if group_size <= 0 or group_size > 1024: + raise ValueError("group_size must be within (0, 1024]") + if blocks_per_group <= 0: + raise ValueError("blocks_per_group must be positive") + + self.value_count = int(max_value) + self.block_span = int(block_span) + self.group_size = int(group_size) + self.blocks_per_group = int(blocks_per_group) + self.block_count = max((self.value_count + self.block_span - 1) // self.block_span, 1) + self.block_group_count = max((self.block_count + self.blocks_per_group - 
1) // self.blocks_per_group, 1) + + self.last_stage_timings: Optional[Dict[str, float]] = None + self.last_validation_errors: Optional[int] = None + + def inject_defines(filename: str) -> str: + defines = [ + f"#define CS_GROUP_SIZE {self.group_size}u", + f"#define CS_BLOCK_SPAN {self.block_span}u", + ] + return "\n".join(defines) + "\n" + _load_shader_source(filename) + + self._dummy_values_buffer = tf.createBuffer(1, 4, False) + + self._histogram_program = tf.createComputeProgramFromSlang( + "count_histogram_rank", + inject_defines("histogram_rank.slang"), + "csHistogramAndRank", + ro_count=1, + rw_count=2, + push_constant_size=8, + ) + self._segment_scan_program = tf.createComputeProgramFromSlang( + "count_segment_scan", + inject_defines("block_sum.slang"), + "csSegmentScan", + ro_count=1, + rw_count=2, + push_constant_size=12, + ) + self._block_prefix_stage2_program = tf.createComputeProgramFromSlang( + "count_block_prefix_stage2", + inject_defines("block_prefix_stage2.slang"), + "csBlockPrefixStage2", + ro_count=1, + rw_count=1, + push_constant_size=4, + ) + self._scatter_program = tf.createComputeProgramFromSlang( + "count_scatter", + inject_defines("scatter.slang"), + "csScatter", + ro_count=6, + rw_count=2, + push_constant_size=20, + ) + self._validate_program = tf.createComputeProgramFromSlang( + "count_validate", + inject_defines("validate_sorted.slang"), + "csValidate", + ro_count=1, + rw_count=1, + push_constant_size=4, + ) + + def sort( + self, + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + collect_stage_timings: bool = False, + validate: bool = False, + return_arrays: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + keys_array = _prepare_keys(keys) + element_count = int(keys_array.shape[0]) + + if values is not None: + values_array, values_dtype = _prepare_values(values) + if values_array.shape[0] != element_count: + raise ValueError("values must have the same length as keys") + else: + values_array = None + 
values_dtype = None + + if element_count == 0: + self.last_stage_timings = {} if collect_stage_timings else None + if validate: + self.last_validation_errors = 0 + if values_array is None: + return keys_array.copy(), None + return keys_array.copy(), values_array.copy() + + if np.max(keys_array) >= self.value_count: + raise ValueError(f"keys must be within [0, {self.value_count}) for this sorter") + + if collect_stage_timings: + stage_totals: Dict[str, float] = { + "histogram": 0.0, + "block_sum": 0.0, + "block_prefix_stage1": 0.0, + "block_prefix_stage2": 0.0, + "scatter": 0.0, + } + else: + stage_totals = {} + + key_in_buffer = tf.createBuffer(max(element_count, 1), 4, False) + key_in_buffer.setData(keys_array) + key_out_buffer = tf.createBuffer(max(element_count, 1), 4, False) + + if values_array is not None: + val_in_buffer = tf.createBuffer(max(element_count, 1), 4, False) + val_in_buffer.setData(values_array) + val_out_buffer = tf.createBuffer(max(element_count, 1), 4, False) + else: + dummy = self._dummy_values_buffer + val_in_buffer = dummy + val_out_buffer = dummy + + histogram_buffer = tf.createBuffer(max(self.value_count, 1), 4, False) + z_hist = np.zeros(self.value_count, dtype=np.uint32) + histogram_buffer.setData(z_hist) + + local_rank_buffer = tf.createBuffer(max(element_count, 1), 4, False) + local_prefix_buffer = tf.createBuffer(max(self.value_count, 1), 4, False) + block_totals_buffer = tf.createBuffer(max(self.block_count, 1), 4, False) + block_prefix_stage1_buffer = tf.createBuffer(max(self.block_count, 1), 4, False) + block_group_totals_buffer = tf.createBuffer(max(self.block_group_count, 1), 4, False) + block_group_prefix_buffer = tf.createBuffer(max(self.block_group_count, 1), 4, False) + + hist_params = np.zeros(2, dtype=np.uint32) + hist_params[0] = np.uint32(element_count) + hist_params[1] = np.uint32(self.value_count) + + segment_params_hist = np.zeros(3, dtype=np.uint32) + segment_params_hist[0] = np.uint32(self.block_count) + 
segment_params_hist[1] = np.uint32(self.block_span) + segment_params_hist[2] = np.uint32(self.value_count) + + stage1_params = np.zeros(3, dtype=np.uint32) + stage1_params[0] = np.uint32(self.block_group_count) + stage1_params[1] = np.uint32(self.blocks_per_group) + stage1_params[2] = np.uint32(self.block_count) + + stage2_params = np.zeros(1, dtype=np.uint32) + stage2_params[0] = np.uint32(self.block_group_count) + + scatter_params = np.zeros(5, dtype=np.uint32) + scatter_params[0] = np.uint32(element_count) + scatter_params[1] = np.uint32(self.value_count) + scatter_params[2] = np.uint32(self.block_span) + scatter_params[3] = np.uint32(self.blocks_per_group) + scatter_params[4] = np.uint32(1 if values_array is not None else 0) + + histogram_groups = _dispatch_groups(element_count, self.group_size) + block_sum_groups = _dispatch_groups(self.block_count, 64) + block_prefix_stage1_groups = _dispatch_groups(self.block_group_count, 64) + block_prefix_stage2_groups = _dispatch_groups(self.block_group_count, 64) + scatter_groups = _dispatch_groups(element_count, self.group_size) + + total_start = time.perf_counter() if collect_stage_timings else None + + start = time.perf_counter() if collect_stage_timings else None + self._histogram_program.run( + [key_in_buffer], + [histogram_buffer, local_rank_buffer], + histogram_groups, + hist_params, + ) + if collect_stage_timings and start is not None: + stage_totals["histogram"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._segment_scan_program.run( + [histogram_buffer], + [local_prefix_buffer, block_totals_buffer], + block_sum_groups, + segment_params_hist, + ) + if collect_stage_timings and start is not None: + stage_totals["block_sum"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._segment_scan_program.run( + [block_totals_buffer], + [block_prefix_stage1_buffer, block_group_totals_buffer], + 
block_prefix_stage1_groups, + stage1_params, + ) + if collect_stage_timings and start is not None: + stage_totals["block_prefix_stage1"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._block_prefix_stage2_program.run( + [block_group_totals_buffer], + [block_group_prefix_buffer], + block_prefix_stage2_groups, + stage2_params, + ) + if collect_stage_timings and start is not None: + stage_totals["block_prefix_stage2"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._scatter_program.run( + [ + key_in_buffer, + val_in_buffer, + local_rank_buffer, + local_prefix_buffer, + block_prefix_stage1_buffer, + block_group_prefix_buffer, + ], + [key_out_buffer, val_out_buffer], + scatter_groups, + scatter_params, + ) + if collect_stage_timings and start is not None: + stage_totals["scatter"] += time.perf_counter() - start + + if collect_stage_timings and total_start is not None: + stage_totals["total_pass"] = time.perf_counter() - total_start + + if validate: + validate_params = np.zeros(1, dtype=np.uint32) + validate_params[0] = np.uint32(element_count) + + error_buf = tf.createBuffer(1, 4, False) + error_zero = np.zeros(1, dtype=np.uint32) + error_buf.setData(error_zero) + + self._validate_program.run( + [key_out_buffer], + [error_buf], + _dispatch_groups(element_count, self.group_size), + validate_params, + ) + + error_count = int(error_buf.getData(np.dtype(np.uint32), 1)[0]) + self.last_validation_errors = error_count + else: + self.last_validation_errors = None + + if return_arrays: + sorted_keys = key_out_buffer.getData(np.dtype(np.uint32), element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_out_buffer.getData(values_dtype, element_count) + else: + sorted_values = None + else: + sorted_keys = np.empty(0, dtype=np.uint32) + sorted_values = None if values_array is None or values_dtype is None else np.empty(0, 
dtype=values_dtype)
+
+        self.last_stage_timings = stage_totals if collect_stage_timings else None
+        return sorted_keys, sorted_values
+
+
+_SORTER_CACHE: Dict[_SorterKey, CountingSort] = {}
+
+
+def _get_sorter(max_value: int, block_span: int, group_size: int, blocks_per_group: int) -> CountingSort:
+    key = _SorterKey(max_value, block_span, group_size, blocks_per_group)
+    sorter = _SORTER_CACHE.get(key)
+    if sorter is None:
+        sorter = CountingSort(
+            max_value=max_value,
+            block_span=block_span,
+            group_size=group_size,
+            blocks_per_group=blocks_per_group,
+        )
+        _SORTER_CACHE[key] = sorter
+    return sorter
+
+
+def counting_sort(
+    keys: np.ndarray,
+    values: Optional[np.ndarray] = None,
+    *,
+    max_value: int,
+    block_span: int = 256,
+    group_size: int = 128,
+    blocks_per_group: int = 256,
+) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    """Sort ``keys`` (and optional ``values``) with a cached :class:`CountingSort`; returns the sorted keys alone when ``values`` is None, otherwise a ``(sorted_keys, sorted_values)`` tuple."""
+
+    if max_value is None:
+        raise ValueError("max_value must be provided for counting_sort")
+
+    sorter = _get_sorter(int(max_value), int(block_span), int(group_size), int(blocks_per_group))
+    sorted_keys, sorted_values = sorter.sort(keys, values, return_arrays=True)  # sort() only reads results back when return_arrays is set; without it the wrapper returned empty placeholder arrays
+    if values is None:
+        return sorted_keys
+    return sorted_keys, sorted_values