diff --git a/.gitignore b/.gitignore index 718829be..b64dee42 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,6 @@ imgui.ini *.npz /cmake-build-* *.pyc +/.cmake +/CMakeFiles +/.debug diff --git a/.gitmodules b/.gitmodules index be28e636..fbd707c9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,9 +4,9 @@ [submodule "external/glfw"] path = external/glfw url = https://github.com/glfw/glfw.git -[submodule "external/glad"] - path = external/glad - url = https://github.com/Dav1dde/glad.git [submodule "external/imgui"] path = external/imgui url = https://github.com/ocornut/imgui.git +[submodule "Python/external/shaderc"] + path = Python/external/shaderc + url = https://github.com/google/shaderc diff --git a/.idea/TensorFrost.iml b/.idea/TensorFrost.iml index 1837251a..83f477b5 100644 --- a/.idea/TensorFrost.iml +++ b/.idea/TensorFrost.iml @@ -2,7 +2,7 @@ - + \ No newline at end of file diff --git a/.idea/copilot.data.migration.agent.xml b/.idea/copilot.data.migration.agent.xml new file mode 100644 index 00000000..4ea72a91 --- /dev/null +++ b/.idea/copilot.data.migration.agent.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/editor.xml b/.idea/editor.xml index 6692539a..4d514864 100644 --- a/.idea/editor.xml +++ b/.idea/editor.xml @@ -1,302 +1,60 @@ - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml index a8974b61..a48f14df 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -2,7 +2,6 @@ - diff --git a/.run/TensorFrost.run.xml b/.run/TensorFrost.run.xml index 97e80504..61d7af64 100644 --- a/.run/TensorFrost.run.xml +++ b/.run/TensorFrost.run.xml @@ -1,5 +1,5 @@ - + diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..f15737e5 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,10 @@ +# Agent Guide + +Follow these expectations whenever you work in this repository: + +1. **Full rebuild & virtual environment** — Run `setup_python_env.cmd` from the repo root. It configures the Python virtual environment and performs a clean rebuild so you start from a consistent state. +2. **Partial rebuilds** — Use CMake for incremental builds. Invoke the appropriate CMake build command (for example, `cmake --build --target `) to rebuild only what you need. +3. **C++ changes** — Any edits under `TensorFrost/` or other C++ sources require a rebuild before the changes take effect. +4. **Python script changes** - After edits of python script, you should run them to make sure they work correctly. No recompilation needed. +5. **API validation** — After modifying functionality, run the relevant tests in the `tests/` folder to confirm the Python API still behaves as expected. +6. **Scenario validation** — Run the sample programs in the `examples/` folder to make sure the updated stack handles more complex end-to-end flows. diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c927cf8..09010907 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,10 +34,11 @@ set(GLFW_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(GLFW_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) set(PYBIND11_FINDPYTHON ON) +find_package(Vulkan REQUIRED) + add_subdirectory(external/pybind11) add_subdirectory(external/glfw) -add_subdirectory(external/glad/cmake) add_subdirectory(TensorFrost) add_subdirectory(examples) -set_property(DIRECTORY ${CMAKE_SOURCE_DIR} PROPERTY VS_STARTUP_PROJECT TensorFrost) \ No newline at end of file +set_property(DIRECTORY ${CMAKE_SOURCE_DIR} PROPERTY VS_STARTUP_PROJECT TensorFrost) diff --git a/ProtoIR.txt b/ProtoIR.txt new file mode 100644 index 00000000..ac6519dc --- /dev/null +++ b/ProtoIR.txt @@ -0,0 +1,16 @@ +Node = [name, arguments, attributes] + + +n = input_dim(attributes{type=int32, input_index=0, dim_index=0}) +a = input(args{shape=[n]}, attributes{type=float32, input_index=0}) +b = sin(args{input=[a], shape=[n]}, attributes{type=float32}) +c = load(args{input=[b], indices=[0]}, attributes{type=float32}) +d = average(args{input=[b]}, attributes{type=float32}) +res = div(args{input=[d,c]}, attributes{type=float32}) +ids = parallel(args{shape=[n, n]}, attributes{type=tuple}) { + i = unpack_tuple(args{input=[ids,0]}, attributes{type=int32}) + j = unpack_tuple(args{input=[ids,0]}, attributes{type=int32}) + a0 = load(args{input=[b], indices=[i]}, attributes{type=float32}) + a1 = load(args{input=[b], indices=[j]}, attributes{type=float32}) + outer = mul(args{input=[a0, a1]}, attributes{type=float32, output_index=0}) +} \ No newline at end of file diff --git a/Python/TensorFrost/__init__.py b/Python/TensorFrost/__init__.py index 3fddb9ae..ce11ec6e 100644 --- a/Python/TensorFrost/__init__.py +++ b/Python/TensorFrost/__init__.py @@ -4,7 +4,6 @@ from . import regularizers from . import clipping from . import random -from . import sort from .default import * # def compile(func): diff --git a/Python/TensorFrost/clipping.py b/Python/TensorFrost/clipping.py index b343917e..7c79851d 100644 --- a/Python/TensorFrost/clipping.py +++ b/Python/TensorFrost/clipping.py @@ -1,5 +1,5 @@ -from .optimizers import * - -clamp = ModuleOptimizer.ClippingType.Clamp -norm = ModuleOptimizer.ClippingType.Norm -none = ModuleOptimizer.ClippingType.None_ \ No newline at end of file +# from .optimizers import * +# +# clamp = ModuleOptimizer.ClippingType.Clamp +# norm = ModuleOptimizer.ClippingType.Norm +# none = ModuleOptimizer.ClippingType.None_ \ No newline at end of file diff --git a/Python/TensorFrost/default.py b/Python/TensorFrost/default.py index 41853504..a015108c 100644 --- a/Python/TensorFrost/default.py +++ b/Python/TensorFrost/default.py @@ -1,14 +1,14 @@ -from . import TensorFrost as tf - -def zeros_like(tensor): - return tf.zeros(tensor.shape, tensor.type) - -def eye(n): - i, j = tf.indices([n, n]) - return tf.select(i == j, 1.0, 0.0) - -def eye_like(tensor): - return eye(tensor.shape[0]) - -def ones_like(tensor): - return tf.ones(tensor.shape, tensor.type) \ No newline at end of file +# from . import TensorFrost as tf +# +# def zeros_like(tensor): +# return tf.zeros(tensor.shape, tensor.type) +# +# def eye(n): +# i, j = tf.indices([n, n]) +# return tf.select(i == j, 1.0, 0.0) +# +# def eye_like(tensor): +# return eye(tensor.shape[0]) +# +# def ones_like(tensor): +# return tf.ones(tensor.shape, tensor.type) \ No newline at end of file diff --git a/Python/TensorFrost/optimizers.py b/Python/TensorFrost/optimizers.py index 731ae81b..21b07ce4 100644 --- a/Python/TensorFrost/optimizers.py +++ b/Python/TensorFrost/optimizers.py @@ -1,219 +1,219 @@ -from . import TensorFrost as tf - -class ModuleOptimizer(tf.Module): - class OptimizerType: - ADAM = 0 - SGD = 1 - RMSProp = 2 - - class RegularizerType: - None_ = 0 - L1 = 1 - L2 = 2 - - class ClippingType: - Clamp = 0 - Norm = 1 - None_ = 2 - - def __init__(self, optimizer_type, regularizer_type, net, params): - super().__init__() - self.optimizer_type = optimizer_type - self.regularizer_type = regularizer_type - self.clipping_type = self.ClippingType.Clamp - self.epsilon = 1e-8 - - # Set passed parameters as attributes - self.net = net - for k, v in params.items(): - setattr(self, k, v) - - # Initialize t - t = tf.Parameter([1], tf.float32, False) # mimic Parameter({1}, TFType::Float, false) - self.t = t - - self.initializeOptimizer(net) - - def set_clipping_type(self, ctype): - self.clipping_type = ctype - - def initializeOptimizer(self, net): - net_params = net.parameters() - requires_grads = net.requires_grads_list() - - if self.optimizer_type == self.OptimizerType.ADAM: - self.initializeParameterArray("m", net_params, requires_grads) - self.initializeParameterArray("v", net_params, requires_grads) - elif self.optimizer_type == self.OptimizerType.SGD: - # No additional parameters needed - pass - elif self.optimizer_type == self.OptimizerType.RMSProp: - self.initializeParameterArray("v", net_params, requires_grads) - - def initializeParameterArray(self, name, net_params, requires_grads): - arr = tf.ParameterArray() - - for i, param in enumerate(net_params): - if not requires_grads[i]: - continue - - new_param = tf.Parameter(param.shape, tf.float32, False) - arr[i] = new_param - - setattr(self, name, arr) - - def assert_parameters(self): - net_params = self.net.parameters() - requires_grads = self.net.requires_grads_list() - self.assertParameterArray("m", net_params, requires_grads) - self.assertParameterArray("v", net_params, requires_grads) - - def gradient_norm(self, grad): - # sum of squares - g = grad * grad - shape = grad.shape - num_dims = len(shape) - for i in range(num_dims): - g = tf.sum(g) - return tf.sqrt(g) - - def assertParameterArray(self, name, net_params, requires_grads): - if hasattr(self, name): - arr = getattr(self, name) - for i, param in enumerate(net_params): - if not requires_grads[i]: - continue - arr_item = arr[i] - arr_item = tf.assert_tensor(arr_item, param.shape, param.type) - arr[i] = arr_item - - def step(self, *args): - # Overloaded step: - # step(X, Y) or step(loss) - if len(args) == 2: - X, Y = args - loss = self.net.loss(X, Y) - self._step(loss) - return loss - elif len(args) == 1: - (loss,) = args - self._step(loss) - else: - raise ValueError("Invalid arguments to step") - - def _step(self, loss): - # Increment t by 1 - self.t = self.t + 1.0 - - net = self.net - net_params = net.parameters() - requires_grads = net.requires_grads_list() - - learning_rate = self.learning_rate - grad_clip = self.grad_clip - has_clip = isinstance(grad_clip, float) and grad_clip > 0.0 - - for i, param in enumerate(net_params): - if not requires_grads[i]: - continue - - grad = tf.grad(loss, param) - if has_clip: - if self.clipping_type == self.ClippingType.Clamp: - grad = tf.clamp(grad, -grad_clip, grad_clip) - elif self.clipping_type == self.ClippingType.Norm: - grad_norm = tf.max(1e-6, self.gradient_norm(grad)) - grad = grad * tf.min(1.0, grad_clip / grad_norm) - - if self.optimizer_type == self.OptimizerType.ADAM: - update = self.adam_update(i, param, grad, self.t, learning_rate) - elif self.optimizer_type == self.OptimizerType.SGD: - update = self.sgd_update(param, grad, learning_rate) - elif self.optimizer_type == self.OptimizerType.RMSProp: - update = self.rmsprop_update(i, param, grad, learning_rate) - else: - raise RuntimeError("Unknown optimizer type") - - # Apply regularization if needed - if self.regularizer_type == self.RegularizerType.L1: - param = param - learning_rate * self.reg * tf.sign(param) - elif self.regularizer_type == self.RegularizerType.L2: - param = param - learning_rate * self.reg * param - - # Update parameter with computed update - param = param - update - net_params[i] = param - - net.update_parameters(net_params) - - def adam_update(self, i, param, grad, t, learning_rate): - beta1 = tf.float(self.beta1) - beta2 = tf.float(self.beta2) - - m = self.m[i] - v = self.v[i] - - m = tf.lerp(grad, m, beta1) - v = tf.lerp(grad * grad, v, beta2) - - # t is a Parameter with shape [1]; get the scalar - t_val = self.t[0] - mhat = m / (1.0 - tf.pow(beta1, t_val)) - vhat = v / (1.0 - tf.pow(beta2, t_val)) - - self.m[i] = m - self.v[i] = v - - return learning_rate * mhat / (tf.sqrt(vhat) + self.epsilon) - - def sgd_update(self, param, grad, learning_rate): - return learning_rate * grad - - def rmsprop_update(self, i, param, grad, learning_rate): - decay = tf.float(self.decay) - - v = self.v[i] - v = tf.lerp(grad * grad, v, decay) - self.v[i] = v - - return (grad * learning_rate) / (tf.sqrt(v) + self.epsilon) - - -def adam(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, beta1=0.9, beta2=0.999, clip=0.0, reg=0.0): - return ModuleOptimizer( - ModuleOptimizer.OptimizerType.ADAM, - reg_type, - net, - { - "learning_rate": learning_rate, - "beta1": beta1, - "beta2": beta2, - "grad_clip": clip, - "reg": reg, - } - ) - -def sgd(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, clip=0.0, reg=0.0): - return ModuleOptimizer( - ModuleOptimizer.OptimizerType.SGD, - reg_type, - net, - { - "learning_rate": learning_rate, - "grad_clip": clip, - "reg": reg, - } - ) - -def rmsprop(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, decay=0.9, clip=0.0, reg=0.0): - return ModuleOptimizer( - ModuleOptimizer.OptimizerType.RMSProp, - reg_type, - net, - { - "learning_rate": learning_rate, - "decay": decay, - "grad_clip": clip, - "reg": reg, - } - ) \ No newline at end of file +# from . import TensorFrost as tf +# +# class ModuleOptimizer(tf.Module): +# class OptimizerType: +# ADAM = 0 +# SGD = 1 +# RMSProp = 2 +# +# class RegularizerType: +# None_ = 0 +# L1 = 1 +# L2 = 2 +# +# class ClippingType: +# Clamp = 0 +# Norm = 1 +# None_ = 2 +# +# def __init__(self, optimizer_type, regularizer_type, net, params): +# super().__init__() +# self.optimizer_type = optimizer_type +# self.regularizer_type = regularizer_type +# self.clipping_type = self.ClippingType.Clamp +# self.epsilon = 1e-8 +# +# # Set passed parameters as attributes +# self.net = net +# for k, v in params.items(): +# setattr(self, k, v) +# +# # Initialize t +# t = tf.Parameter([1], tf.float32, False) # mimic Parameter({1}, TFType::Float, false) +# self.t = t +# +# self.initializeOptimizer(net) +# +# def set_clipping_type(self, ctype): +# self.clipping_type = ctype +# +# def initializeOptimizer(self, net): +# net_params = net.parameters() +# requires_grads = net.requires_grads_list() +# +# if self.optimizer_type == self.OptimizerType.ADAM: +# self.initializeParameterArray("m", net_params, requires_grads) +# self.initializeParameterArray("v", net_params, requires_grads) +# elif self.optimizer_type == self.OptimizerType.SGD: +# # No additional parameters needed +# pass +# elif self.optimizer_type == self.OptimizerType.RMSProp: +# self.initializeParameterArray("v", net_params, requires_grads) +# +# def initializeParameterArray(self, name, net_params, requires_grads): +# arr = tf.ParameterArray() +# +# for i, param in enumerate(net_params): +# if not requires_grads[i]: +# continue +# +# new_param = tf.Parameter(param.shape, tf.float32, False) +# arr[i] = new_param +# +# setattr(self, name, arr) +# +# def assert_parameters(self): +# net_params = self.net.parameters() +# requires_grads = self.net.requires_grads_list() +# self.assertParameterArray("m", net_params, requires_grads) +# self.assertParameterArray("v", net_params, requires_grads) +# +# def gradient_norm(self, grad): +# # sum of squares +# g = grad * grad +# shape = grad.shape +# num_dims = len(shape) +# for i in range(num_dims): +# g = tf.sum(g) +# return tf.sqrt(g) +# +# def assertParameterArray(self, name, net_params, requires_grads): +# if hasattr(self, name): +# arr = getattr(self, name) +# for i, param in enumerate(net_params): +# if not requires_grads[i]: +# continue +# arr_item = arr[i] +# arr_item = tf.assert_tensor(arr_item, param.shape, param.type) +# arr[i] = arr_item +# +# def step(self, *args): +# # Overloaded step: +# # step(X, Y) or step(loss) +# if len(args) == 2: +# X, Y = args +# loss = self.net.loss(X, Y) +# self._step(loss) +# return loss +# elif len(args) == 1: +# (loss,) = args +# self._step(loss) +# else: +# raise ValueError("Invalid arguments to step") +# +# def _step(self, loss): +# # Increment t by 1 +# self.t = self.t + 1.0 +# +# net = self.net +# net_params = net.parameters() +# requires_grads = net.requires_grads_list() +# +# learning_rate = self.learning_rate +# grad_clip = self.grad_clip +# has_clip = isinstance(grad_clip, float) and grad_clip > 0.0 +# +# for i, param in enumerate(net_params): +# if not requires_grads[i]: +# continue +# +# grad = tf.grad(loss, param) +# if has_clip: +# if self.clipping_type == self.ClippingType.Clamp: +# grad = tf.clamp(grad, -grad_clip, grad_clip) +# elif self.clipping_type == self.ClippingType.Norm: +# grad_norm = tf.max(1e-6, self.gradient_norm(grad)) +# grad = grad * tf.min(1.0, grad_clip / grad_norm) +# +# if self.optimizer_type == self.OptimizerType.ADAM: +# update = self.adam_update(i, param, grad, self.t, learning_rate) +# elif self.optimizer_type == self.OptimizerType.SGD: +# update = self.sgd_update(param, grad, learning_rate) +# elif self.optimizer_type == self.OptimizerType.RMSProp: +# update = self.rmsprop_update(i, param, grad, learning_rate) +# else: +# raise RuntimeError("Unknown optimizer type") +# +# # Apply regularization if needed +# if self.regularizer_type == self.RegularizerType.L1: +# param = param - learning_rate * self.reg * tf.sign(param) +# elif self.regularizer_type == self.RegularizerType.L2: +# param = param - learning_rate * self.reg * param +# +# # Update parameter with computed update +# param = param - update +# net_params[i] = param +# +# net.update_parameters(net_params) +# +# def adam_update(self, i, param, grad, t, learning_rate): +# beta1 = tf.float(self.beta1) +# beta2 = tf.float(self.beta2) +# +# m = self.m[i] +# v = self.v[i] +# +# m = tf.lerp(grad, m, beta1) +# v = tf.lerp(grad * grad, v, beta2) +# +# # t is a Parameter with shape [1]; get the scalar +# t_val = self.t[0] +# mhat = m / (1.0 - tf.pow(beta1, t_val)) +# vhat = v / (1.0 - tf.pow(beta2, t_val)) +# +# self.m[i] = m +# self.v[i] = v +# +# return learning_rate * mhat / (tf.sqrt(vhat) + self.epsilon) +# +# def sgd_update(self, param, grad, learning_rate): +# return learning_rate * grad +# +# def rmsprop_update(self, i, param, grad, learning_rate): +# decay = tf.float(self.decay) +# +# v = self.v[i] +# v = tf.lerp(grad * grad, v, decay) +# self.v[i] = v +# +# return (grad * learning_rate) / (tf.sqrt(v) + self.epsilon) +# +# +# def adam(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, beta1=0.9, beta2=0.999, clip=0.0, reg=0.0): +# return ModuleOptimizer( +# ModuleOptimizer.OptimizerType.ADAM, +# reg_type, +# net, +# { +# "learning_rate": learning_rate, +# "beta1": beta1, +# "beta2": beta2, +# "grad_clip": clip, +# "reg": reg, +# } +# ) +# +# def sgd(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, clip=0.0, reg=0.0): +# return ModuleOptimizer( +# ModuleOptimizer.OptimizerType.SGD, +# reg_type, +# net, +# { +# "learning_rate": learning_rate, +# "grad_clip": clip, +# "reg": reg, +# } +# ) +# +# def rmsprop(net, reg_type=ModuleOptimizer.RegularizerType.None_, learning_rate=0.001, decay=0.9, clip=0.0, reg=0.0): +# return ModuleOptimizer( +# ModuleOptimizer.OptimizerType.RMSProp, +# reg_type, +# net, +# { +# "learning_rate": learning_rate, +# "decay": decay, +# "grad_clip": clip, +# "reg": reg, +# } +# ) \ No newline at end of file diff --git a/Python/TensorFrost/random.py b/Python/TensorFrost/random.py index de3d0ad0..ab91f028 100644 --- a/Python/TensorFrost/random.py +++ b/Python/TensorFrost/random.py @@ -1,45 +1,45 @@ -from . import TensorFrost as tf - -def randn2(shape, seed=0): - #Box-Muller transform - r1 = tf.random_value(shape, seed=seed) - r2 = tf.random_value(shape, seed=tf.hash(seed)) - rho = tf.sqrt(-2.0*tf.log(tf.max(1e-6, r1))) - theta = 2.0*tf.pi*r2 - return rho*tf.cos(theta), rho*tf.sin(theta) - -def randn(shape, seed=0): - return randn2(shape, seed=seed)[0] - -def rand(shape, seed=0): - return tf.random_value(shape, seed=seed) - -def randn_like(tensor, seed=0): - return randn(tensor.shape, seed=seed) - -def rand_like(tensor, seed=0): - return rand(tensor.shape, seed=seed) - -def rand_int(seed, max_value): - return tf.int(tf.pcg(tf.uint(seed)) % tf.uint(max_value)) - -def xor_swap(idx, n, seed): - xor_seed = rand_int(seed, n) - xor_idx = (idx ^ xor_seed) - max_idx = tf.max(idx, xor_idx) - min_idx = tf.min(idx, xor_idx) - swap = rand_int(min_idx * 451 + seed, 2) == 0 - return tf.select(swap & (max_idx < n), xor_idx, idx) - -def reverse(idx, n): - return n - 1 - idx - -def shuffle(idx, n, seed = 0, iters = 16): - for i in range(iters): - idx = xor_swap(idx, n, seed + i) - idx = reverse(idx, n) - return idx - -def permutation(n, seed = 0): - idx = tf.indices([n])[0] - return shuffle(idx, n, seed) \ No newline at end of file +# from . import TensorFrost as tf +# +# def randn2(shape, seed=0): +# #Box-Muller transform +# r1 = tf.random_value(shape, seed=seed) +# r2 = tf.random_value(shape, seed=tf.hash(seed)) +# rho = tf.sqrt(-2.0*tf.log(tf.max(1e-6, r1))) +# theta = 2.0*tf.pi*r2 +# return rho*tf.cos(theta), rho*tf.sin(theta) +# +# def randn(shape, seed=0): +# return randn2(shape, seed=seed)[0] +# +# def rand(shape, seed=0): +# return tf.random_value(shape, seed=seed) +# +# def randn_like(tensor, seed=0): +# return randn(tensor.shape, seed=seed) +# +# def rand_like(tensor, seed=0): +# return rand(tensor.shape, seed=seed) +# +# def rand_int(seed, max_value): +# return tf.int(tf.pcg(tf.uint(seed)) % tf.uint(max_value)) +# +# def xor_swap(idx, n, seed): +# xor_seed = rand_int(seed, n) +# xor_idx = (idx ^ xor_seed) +# max_idx = tf.max(idx, xor_idx) +# min_idx = tf.min(idx, xor_idx) +# swap = rand_int(min_idx * 451 + seed, 2) == 0 +# return tf.select(swap & (max_idx < n), xor_idx, idx) +# +# def reverse(idx, n): +# return n - 1 - idx +# +# def shuffle(idx, n, seed = 0, iters = 16): +# for i in range(iters): +# idx = xor_swap(idx, n, seed + i) +# idx = reverse(idx, n) +# return idx +# +# def permutation(n, seed = 0): +# idx = tf.indices([n])[0] +# return shuffle(idx, n, seed) \ No newline at end of file diff --git a/Python/TensorFrost/regularizers.py b/Python/TensorFrost/regularizers.py index b384bcef..989767cc 100644 --- a/Python/TensorFrost/regularizers.py +++ b/Python/TensorFrost/regularizers.py @@ -1,5 +1,5 @@ -from .optimizers import * - -l1 = ModuleOptimizer.RegularizerType.L1 -l2 = ModuleOptimizer.RegularizerType.L2 -none = ModuleOptimizer.RegularizerType.None_ \ No newline at end of file +# from .optimizers import * +# +# l1 = ModuleOptimizer.RegularizerType.L1 +# l2 = ModuleOptimizer.RegularizerType.L2 +# none = ModuleOptimizer.RegularizerType.None_ \ No newline at end of file diff --git a/Python/TensorFrost/sort.py b/Python/TensorFrost/sort.py index 4d585c64..4a236bb9 100644 --- a/Python/TensorFrost/sort.py +++ b/Python/TensorFrost/sort.py @@ -1,187 +1,516 @@ +from __future__ import annotations + +from dataclasses import dataclass +from importlib import resources +from typing import Dict, Optional, Tuple + +import numpy as np + from . import TensorFrost as tf -#in-place bitonic sort -def bitonic(keys, values = None): - tf.region_begin('Bitonic sort') - keys = tf.copy(keys) - if values is not None: - values = tf.copy(values) - element_count = keys.shape[0] - log2_count = tf.int(tf.ceil(tf.log2(tf.float(element_count)))) - count_round = 1 << log2_count - idx = tf.indices([count_round / 2])[0] - with tf.loop(log2_count) as k: - with tf.loop(k+1) as j: - s = 1 << (k-j) - m_inner = s - 1 - m_outer = ~m_inner - m_xor = s + tf.select(j == 0, m_inner, 0) - - id1 = (2 * (idx & m_outer) + (idx & m_inner)) - id2 = id1 ^ m_xor - key1, key2 = keys[id1], keys[id2] - with tf.if_cond((key1 >= key2) & (id1 < element_count) & (id2 < element_count)): - if values is not None: - val1, val2 = values[id1], values[id2] - values[id1] = val2 - values[id2] = val1 - keys[id1] = key2 - keys[id2] = key1 - - tf.region_end('Bitonic sort') - if values is not None: - return keys, values - else: - return keys - -#histogram radix sort -def radix(keys, values = None, bits_per_pass = 6, max_bits = 32): - def prefix_sum_grouped(A, axis = -1): - axis = len(A.shape) + axis if axis < 0 else axis - group_size = 64 - grouped = tf.split_dim(A, group_size, axis) - group_scan = tf.prefix_sum(tf.sum(grouped, axis = axis + 1), axis = axis) - ids = grouped.indices - gid, eid = ids[axis], ids[axis + 1] - ids = [ids[i] for i in range(len(ids)) if i != axis + 1] - ids[axis] = gid - 1 - group_scan = tf.prefix_sum(grouped + tf.select((gid == 0) | (eid != 0), tf.uint(0), group_scan[tuple(ids)]), axis = axis + 1) - full_scan = tf.merge_dim(group_scan, target_size = A.shape[axis], axis = axis + 1) - return full_scan - - sign_bit = ~tf.uint(0x7FFFFFFF) - - def map_float_to_uint(x): - # Convert float to uint representation - ux = tf.asuint(x) - # Compute mask - mask = tf.select((ux >> 31) == 1, ~tf.uint(0), sign_bit) - # Apply XOR - return ux ^ mask - - def map_uint_to_float(x): - # Compute mask - mask = tf.select((x >> 31) == 0, ~tf.uint(0), sign_bit) - # Apply XOR and convert back to float - return tf.asfloat(x ^ mask) - - def map_int_to_uint(x): - return tf.asuint(x) ^ sign_bit - - def map_uint_to_int(x): - return tf.asint(x ^ sign_bit) - - tf.region_begin('Radix sort') - - has_values = values is not None - - keys = tf.copy(keys) - if has_values: - values = tf.copy(values) - - original_type = keys.type - if(original_type == tf.float32): - keys = map_float_to_uint(keys) - - if(original_type == tf.int32): - keys = map_int_to_uint(keys) - - iters = (max_bits + bits_per_pass - 1) // bits_per_pass - group_size = 128 - histogram_size = 2 ** bits_per_pass - - def GetBits(A, i): - return (A >> (i * bits_per_pass)) & tf.uint(histogram_size - 1) - - keys1 = tf.buffer(keys.shape, keys.type) - values1 = None - - if has_values: - values1 = tf.buffer(values.shape, values.type) - - with tf.loop(iters // 2) as iter: - def SortIteration(keys_in, keys_out, values_in, values_out, iter): - tf.region_begin('Radix sort iteration') - grouped = tf.split_dim(GetBits(keys_in, iter), group_size) - - # Do a packed histogram, since we sum 128 elements at a time, we can pack 4 values into a single uint32 - g, e, i = tf.indices([grouped.shape[0], grouped.shape[1], tf.int(histogram_size/4)]) - this_key = grouped[g, e] - packed_is_bit = (tf.uint(this_key == tf.uint(4*i))) + (tf.uint(this_key == tf.uint(4*i+1)) << 8) + (tf.uint(this_key == tf.uint(4*i+2)) << 16) + (tf.uint(this_key == tf.uint(4*i+3)) << 24) - packed_is_bit = tf.select((g*group_size + e) < keys_in.shape[0], packed_is_bit, tf.uint(0)) - group_histogram_packed = tf.sum(packed_is_bit, axis = 1) - - g, i = tf.indices([grouped.shape[0], histogram_size]) - group_histogram = tf.uint((group_histogram_packed[g, i / 4] >> (8*(i % 4))) & tf.uint(0xFF)) - - group_histogram_scan = prefix_sum_grouped(group_histogram, axis = 0) - i, = tf.indices([histogram_size]) - total_bit_histogram = tf.prefix_sum(group_histogram_scan[group_histogram_scan.shape[0] - 1, i]) - - with tf.kernel(grouped.shape, group_size=[group_size]) as (g, e): - if(tf.current_backend() == tf.cpu): #dont use group barriers on CPU - doesn't work - element = g * group_size + e - with tf.if_cond(element < keys_in.shape[0]): - old_key = keys_in[element] - old_val = values_in[element] - bit = GetBits(old_key, iter) - total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, bit]) + tf.select(bit == tf.uint(0), tf.uint(0), total_bit_histogram[bit - tf.uint(1)]) - with tf.loop(e) as j: - total_offset.val += tf.uint(grouped[g, j] == bit) - keys_out[total_offset] = old_key - values_out[total_offset] = old_val - else: - temp = tf.group_buffer(group_size, tf.uint32) - half_count = tf.group_buffer(histogram_size, tf.uint32) - gtid = g.block_thread_index(0) - - #initialize counters - for i in range((histogram_size + group_size - 1) // group_size): - index = gtid + i * group_size - with tf.if_cond(index < histogram_size): - half_count[index] = 0 - tf.group_barrier() - - element = g * group_size + e - with tf.if_cond(element < keys_in.shape[0]): - old_key = keys_in[element] - bit = GetBits(old_key, iter) - temp[gtid] = bit - - #count number of bits set in previous sub groups - quarter_index = e / (group_size // 4) - with tf.if_cond(quarter_index < 3): - tf.scatterAdd(half_count[bit], tf.uint(quarter_index < 1) | (tf.uint(quarter_index < 2) << 8) | (tf.uint(quarter_index < 3) << 16)) - - tf.group_barrier() - - if has_values: - old_val = values_in[element] - - total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, tf.int(bit)]) + tf.select(tf.int(bit) == 0, tf.uint(0), total_bit_histogram[tf.int(bit) - 1]) - total_offset += tf.select(quarter_index > 0, (half_count[bit] >> (8*(quarter_index-1))) & tf.uint(0xFF), tf.uint(0)) - begin_index = quarter_index * (group_size // 4) - with tf.loop(begin_index, e) as j: - total_offset.val += tf.uint(temp[j] == bit) - keys_out[total_offset] = old_key - - if has_values: - values_out[total_offset] = old_val - - tf.region_end('Radix sort iteration') - - SortIteration(keys, keys1, values, values1, 2 * iter) - SortIteration(keys1, keys, values1, values, 2 * iter + 1) - - tf.region_end('Radix sort') - - if(original_type == tf.float32): - keys = map_uint_to_float(keys) - - if(original_type == tf.int32): - keys = map_uint_to_int(keys) - - if has_values: - return keys, values - else: - return keys +__all__ = ["HistogramRadixSort", "radix_sort"] + +_TYPE_CODES: Dict[str, np.uint32] = { + "uint": np.uint32(0), + "int": np.uint32(1), + "float": np.uint32(2), +} + + +def _dispatch_groups(work_items: int, threads_per_group: int) -> int: + if work_items <= 0: + return 0 + return (work_items + threads_per_group - 1) // threads_per_group + + +def _prepare_keys(keys: np.ndarray) -> Tuple[np.ndarray, np.dtype, str]: + array = np.asarray(keys) + if array.ndim != 1: + raise ValueError("radix_sort expects a 1D array of keys") + + dtype = array.dtype + if dtype == np.uint32: + return array, dtype, "uint" + + if dtype == np.int32: + return array, dtype, "int" + + if dtype == np.float32: + return array, dtype, "float" + + raise TypeError(f"Unsupported key dtype {dtype}; expected uint32, int32, or float32") + + +def _prepare_values(values: np.ndarray) -> Tuple[np.ndarray, np.dtype]: + array = np.asarray(values) + if array.ndim != 1: + raise ValueError("radix_sort expects a 1D array of values when provided") + + dtype = array.dtype + if dtype not in (np.uint32, np.int32, np.float32): + raise TypeError(f"Unsupported value dtype {dtype}; expected uint32, int32, or float32") + + return array, dtype + + +def _load_shader_source(filename: str) -> str: + package = f"{__package__}.shaders.radix" + try: + return resources.files(package).joinpath(filename).read_text(encoding="utf-8") # type: ignore[attr-defined] + except AttributeError: + return resources.read_text(package, filename) + + +@dataclass(frozen=True) +class _SorterKey: + bits_per_pass: int + block_size: int + group_size: int + + +class HistogramRadixSort: + """GPU histogram radix sort implemented with Slang + Vulkan.""" + + def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: int = 128) -> None: + if bits_per_pass <= 0: + raise ValueError("bits_per_pass must be positive") + if bits_per_pass > 8: + raise ValueError("bits_per_pass must be <= 8 to fit within MAX_HIST_SIZE") + if group_size != 128: + raise ValueError("This implementation currently requires group_size == 128") + if block_size <= 0 or block_size > 1024: + raise ValueError("block_size must be within (0, 1024]") + + self.bits_per_pass = bits_per_pass + self.block_size = block_size + self.group_size = group_size + self.histogram_size = 1 << bits_per_pass + + self._map_to_uint_program = tf.createComputeProgramFromSlang( + "radix_map_to_uint", + _load_shader_source("map_to_uint.slang"), + "csMapToUint", + ro_count=1, + rw_count=1, + push_constant_size=8, + ) + self._map_from_uint_program = tf.createComputeProgramFromSlang( + "radix_map_from_uint", + _load_shader_source("map_from_uint.slang"), + "csMapFromUint", + ro_count=1, + rw_count=1, + push_constant_size=8, + ) + + self._histogram_program = tf.createComputeProgramFromSlang( + "radix_histogram", + _load_shader_source("histogram.slang"), + "csHistogram", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + self._unpack_program = tf.createComputeProgramFromSlang( + "radix_unpack", + _load_shader_source("unpack.slang"), + "csUnpack", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + self._prefix_local_program = tf.createComputeProgramFromSlang( + "radix_prefix_local", + _load_shader_source("prefix_local.slang"), + "csPrefixLocal", + ro_count=1, + rw_count=2, + push_constant_size=32, + ) + self._prefix_blocks_program = tf.createComputeProgramFromSlang( + "radix_prefix_blocks", + _load_shader_source("prefix_block.slang"), + "csPrefixBlocks", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + self._prefix_accum_program = tf.createComputeProgramFromSlang( + "radix_prefix_accum", + _load_shader_source("prefix_accum.slang"), + "csPrefixAccumulate", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + self._bucket_scan_program = tf.createComputeProgramFromSlang( + "radix_bucket_scan", + _load_shader_source("bucket_scan.slang"), + "csBucketScan", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + scatter_source = f"#define TF_HISTOGRAM_SIZE {self.histogram_size}u\n" + _load_shader_source("scatter.slang") + self._scatter_program = tf.createComputeProgramFromSlang( + "radix_scatter", + scatter_source, + "csScatter", + ro_count=4, + rw_count=2, + ) + + self._dummy_values_buffer = tf.createBuffer(1, 4, False) + + def close(self) -> None: + for attr in ( + "_map_to_uint_program", + "_map_from_uint_program", + "_histogram_program", + "_unpack_program", + "_prefix_local_program", + "_prefix_blocks_program", + "_prefix_accum_program", + "_bucket_scan_program", + "_scatter_program", + ): + setattr(self, attr, None) + self._dummy_values_buffer = None + + def sort( + self, + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + max_bits: int = 32, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + keys_array, key_dtype, key_kind = _prepare_keys(keys) + element_count = int(keys_array.shape[0]) + + if values is not None: + values_array, values_dtype = _prepare_values(values) + if values_array.shape[0] != element_count: + raise ValueError("values must have the same length as keys") + else: + values_array = None + values_dtype = None + + if element_count == 0: + empty_keys = keys_array.copy() + if values_array is None: + return empty_keys, None + return empty_keys, values_array.copy() + + max_bits = int(min(max_bits, 32)) + histogram_size = self.histogram_size + mask = np.uint32(histogram_size - 1) + + num_groups = max((element_count + self.group_size - 1) // self.group_size, 1) + block_count = max((num_groups + self.block_size - 1) // self.block_size, 1) + packed_count = (histogram_size + 3) // 4 + passes = max((max_bits + self.bits_per_pass - 1) // self.bits_per_pass, 1) + + params_array = np.zeros(8, dtype=np.uint32) + params_array[0] = np.uint32(element_count) + params_array[1] = np.uint32(histogram_size) + params_array[3] = mask + params_array[4] = np.uint32(num_groups) + params_array[5] = np.uint32(self.block_size) + params_array[6] = np.uint32(block_count) + params_array[7] = np.uint32(1 if values_array is not None else 0) + + map_params = np.zeros(2, dtype=np.uint32) + map_params[0] = np.uint32(element_count) + map_params[1] = _TYPE_CODES[key_kind] + + key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + key_buffers[0].setData(keys_array) + + if values_array is not None: + value_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + value_buffers[0].setData(values_array) + else: + dummy = self._dummy_values_buffer + value_buffers = [dummy, dummy] + + packed_hist_buffer = tf.createBuffer(max(packed_count * num_groups, 1), 4, False) + group_hist_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + prefix_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + block_totals_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + block_prefix_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + bucket_scan_buffer = tf.createBuffer(max(histogram_size, 1), 4, False) + + map_groups = _dispatch_groups(element_count, self.group_size) + reduction_group_size = 64 + unpack_groups = _dispatch_groups(histogram_size * num_groups, reduction_group_size) + prefix_local_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + prefix_block_groups = _dispatch_groups(histogram_size, reduction_group_size) + prefix_accum_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + bucket_scan_groups = _dispatch_groups(histogram_size, reduction_group_size) + scatter_groups = num_groups + histogram_groups = num_groups + + self._map_to_uint_program.run( + [key_buffers[0]], + [key_buffers[1]], + map_groups, + map_params, + ) + + key_in = key_buffers[1] + key_out = key_buffers[0] + val_in, val_out = value_buffers + + for pass_index in range(passes): + params_array[2] = np.uint32(pass_index * self.bits_per_pass) + + self._histogram_program.run( + [key_in], + [packed_hist_buffer], + histogram_groups, + params_array, + ) + + self._unpack_program.run( + [packed_hist_buffer], + [group_hist_buffer], + unpack_groups, + params_array, + ) + + self._prefix_local_program.run( + [group_hist_buffer], + [prefix_buffer, block_totals_buffer], + prefix_local_groups, + params_array, + ) + + self._prefix_blocks_program.run( + [block_totals_buffer], + [block_prefix_buffer], + prefix_block_groups, + params_array, + ) + + self._prefix_accum_program.run( + [block_prefix_buffer], + [prefix_buffer], + prefix_accum_groups, + params_array, + ) + + self._bucket_scan_program.run( + [prefix_buffer], + [bucket_scan_buffer], + bucket_scan_groups, + params_array, + ) + + self._scatter_program.run( + [key_in, val_in, prefix_buffer, bucket_scan_buffer], + [key_out, val_out], + scatter_groups, + params_array, + ) + + key_in, key_out = key_out, key_in + if values_array is not None: + val_in, val_out = val_out, val_in + + self._map_from_uint_program.run( + [key_in], + [key_out], + map_groups, + map_params, + ) + + sorted_keys = key_out.getData(key_dtype, element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_in.getData(values_dtype, element_count) + else: + sorted_values = None + + return sorted_keys, sorted_values + + +_SORTER_CACHE: Dict[_SorterKey, HistogramRadixSort] = {} + + +def _get_sorter(bits_per_pass: int, block_size: int, group_size: int) -> HistogramRadixSort: + key = _SorterKey(bits_per_pass, block_size, group_size) + sorter = _SORTER_CACHE.get(key) + if sorter is None: + sorter = HistogramRadixSort(bits_per_pass=bits_per_pass, block_size=block_size, group_size=group_size) + _SORTER_CACHE[key] = sorter + return sorter + + +def radix_sort( + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + bits_per_pass: int = 6, + max_bits: int = 32, + block_size: int = 64, + group_size: int = 128, +): + """Run the GPU histogram radix sort on the provided keys (and optional values). + + Returns the sorted keys, and when ``values`` is provided also returns the permuted values. + """ + + sorter = _get_sorter(bits_per_pass, block_size, group_size) + sorted_keys, sorted_values = sorter.sort(keys, values, max_bits=max_bits) + if values is None: + return sorted_keys + return sorted_keys, sorted_values +# def radix(keys, values = None, bits_per_pass = 6, max_bits = 32): +# def prefix_sum_grouped(A, axis = -1): +# axis = len(A.shape) + axis if axis < 0 else axis +# group_size = 64 +# grouped = tf.split_dim(A, group_size, axis) +# group_scan = tf.prefix_sum(tf.sum(grouped, axis = axis + 1), axis = axis) +# ids = grouped.indices +# gid, eid = ids[axis], ids[axis + 1] +# ids = [ids[i] for i in range(len(ids)) if i != axis + 1] +# ids[axis] = gid - 1 +# group_scan = tf.prefix_sum(grouped + tf.select((gid == 0) | (eid != 0), tf.uint(0), group_scan[tuple(ids)]), axis = axis + 1) +# full_scan = tf.merge_dim(group_scan, target_size = A.shape[axis], axis = axis + 1) +# return full_scan +# +# sign_bit = ~tf.uint(0x7FFFFFFF) +# +# def map_float_to_uint(x): +# # Convert float to uint representation +# ux = tf.asuint(x) +# # Compute mask +# mask = tf.select((ux >> 31) == 1, ~tf.uint(0), sign_bit) +# # Apply XOR +# return ux ^ mask +# +# def map_uint_to_float(x): +# # Compute mask +# mask = tf.select((x >> 31) == 0, ~tf.uint(0), sign_bit) +# # Apply XOR and convert back to float +# return tf.asfloat(x ^ mask) +# +# def map_int_to_uint(x): +# return tf.asuint(x) ^ sign_bit +# +# def map_uint_to_int(x): +# return tf.asint(x ^ sign_bit) +# +# tf.region_begin('Radix sort') +# +# has_values = values is not None +# +# keys = tf.copy(keys) +# if has_values: +# values = tf.copy(values) +# +# original_type = keys.type +# if(original_type == tf.float32): +# keys = map_float_to_uint(keys) +# +# if(original_type == tf.int32): +# keys = map_int_to_uint(keys) +# +# iters = (max_bits + bits_per_pass - 1) // bits_per_pass +# group_size = 128 +# histogram_size = 2 ** bits_per_pass +# +# def GetBits(A, i): +# return (A >> (i * bits_per_pass)) & tf.uint(histogram_size - 1) +# +# keys1 = tf.buffer(keys.shape, keys.type) +# values1 = None +# +# if has_values: +# values1 = tf.buffer(values.shape, values.type) +# +# with tf.loop(iters // 2) as iter: +# def SortIteration(keys_in, keys_out, values_in, values_out, iter): +# tf.region_begin('Radix sort iteration') +# grouped = tf.split_dim(GetBits(keys_in, iter), group_size) +# +# # Do a packed histogram, since we sum 128 elements at a time, we can pack 4 values into a single uint32 +# g, e, i = tf.indices([grouped.shape[0], grouped.shape[1], tf.int(histogram_size/4)]) +# this_key = grouped[g, e] +# packed_is_bit = (tf.uint(this_key == tf.uint(4*i))) + (tf.uint(this_key == tf.uint(4*i+1)) << 8) + (tf.uint(this_key == tf.uint(4*i+2)) << 16) + (tf.uint(this_key == tf.uint(4*i+3)) << 24) +# packed_is_bit = tf.select((g*group_size + e) < keys_in.shape[0], packed_is_bit, tf.uint(0)) +# group_histogram_packed = tf.sum(packed_is_bit, axis = 1) +# +# g, i = tf.indices([grouped.shape[0], histogram_size]) +# group_histogram = tf.uint((group_histogram_packed[g, i / 4] >> (8*(i % 4))) & tf.uint(0xFF)) +# +# group_histogram_scan = prefix_sum_grouped(group_histogram, axis = 0) +# i, = tf.indices([histogram_size]) +# total_bit_histogram = tf.prefix_sum(group_histogram_scan[group_histogram_scan.shape[0] - 1, i]) +# +# with tf.kernel(grouped.shape, group_size=[group_size]) as (g, e): +# if(tf.current_backend() == tf.cpu): #dont use group barriers on CPU - doesn't work +# element = g * group_size + e +# with tf.if_cond(element < keys_in.shape[0]): +# old_key = keys_in[element] +# old_val = values_in[element] +# bit = GetBits(old_key, iter) +# total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, bit]) + tf.select(bit == tf.uint(0), tf.uint(0), total_bit_histogram[bit - tf.uint(1)]) +# with tf.loop(e) as j: +# total_offset.val += tf.uint(grouped[g, j] == bit) +# keys_out[total_offset] = old_key +# values_out[total_offset] = old_val +# else: +# temp = tf.group_buffer(group_size, tf.uint32) +# half_count = tf.group_buffer(histogram_size, tf.uint32) +# gtid = g.block_thread_index(0) +# +# #initialize counters +# for i in range((histogram_size + group_size - 1) // group_size): +# index = gtid + i * group_size +# with tf.if_cond(index < histogram_size): +# half_count[index] = 0 +# tf.group_barrier() +# +# element = g * group_size + e +# with tf.if_cond(element < keys_in.shape[0]): +# old_key = keys_in[element] +# bit = GetBits(old_key, iter) +# temp[gtid] = bit +# +# #count number of bits set in previous sub groups +# quarter_index = e / (group_size // 4) +# with tf.if_cond(quarter_index < 3): +# tf.scatterAdd(half_count[bit], tf.uint(quarter_index < 1) | (tf.uint(quarter_index < 2) << 8) | (tf.uint(quarter_index < 3) << 16)) +# +# tf.group_barrier() +# +# if has_values: +# old_val = values_in[element] +# +# total_offset = tf.select(g == 0, tf.uint(0), group_histogram_scan[g - 1, tf.int(bit)]) + tf.select(tf.int(bit) == 0, tf.uint(0), total_bit_histogram[tf.int(bit) - 1]) +# total_offset += tf.select(quarter_index > 0, (half_count[bit] >> (8*(quarter_index-1))) & tf.uint(0xFF), tf.uint(0)) +# begin_index = quarter_index * (group_size // 4) +# with tf.loop(begin_index, e) as j: +# total_offset.val += tf.uint(temp[j] == bit) +# keys_out[total_offset] = old_key +# +# if has_values: +# values_out[total_offset] = old_val +# +# tf.region_end('Radix sort iteration') +# +# SortIteration(keys, keys1, values, values1, 2 * iter) +# SortIteration(keys1, keys, values1, values, 2 * iter + 1) +# +# tf.region_end('Radix sort') +# +# if(original_type == tf.float32): +# keys = map_uint_to_float(keys) +# +# if(original_type == tf.int32): +# keys = map_uint_to_int(keys) +# +# if has_values: +# return keys, values +# else: +# return keys diff --git a/Python/pyproject.toml b/Python/pyproject.toml index 39eb00ff..55af72a3 100644 --- a/Python/pyproject.toml +++ b/Python/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "TensorFrost" -version = "0.7.4" +version = "2.0.0.dev0" description = "A static optimizing tensor compiler with a Python frontend" authors = [{name = "Mykhailo Moroz", email = "michael08840884@gmail.com"}] requires-python = ">=3.7" diff --git a/README.md b/README.md index dd9a14a3..e7b6787e 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ pip install tensorfrost ## From source -You need to have CMake installed to build the library. +You need to have CMake and Vulkan SDK installed to build the library. First clone the repository: ```bash @@ -689,11 +689,12 @@ You can also specify the clipping type for the gradients, by default the value o For debugging convenience there are 2 function types that you can call inside a tensor program: ```python +tf.renderdoc_is_available() tf.renderdoc_start_capture() tf.renderdoc_end_capture() ``` -These functions will start and end a RenderDoc capture, only if python is started from the RenderDoc GUI. This is useful for debugging the OpenGL backend, as it allows you to inspect compiled kernel execution, its code and buffers. +These functions will start and end a RenderDoc capture, only if python is started from the RenderDoc GUI. Call `tf.renderdoc_is_available()` first to check whether RenderDoc is attached so you can skip capture logic when it isn't. This is useful for debugging the OpenGL backend, as it allows you to inspect compiled kernel execution, its code and buffers. ```python tf.region_begin('Region name') diff --git a/TensorFrost/Backend/Backend.cpp b/TensorFrost/Backend/Backend.cpp deleted file mode 100644 index 8163a39e..00000000 --- a/TensorFrost/Backend/Backend.cpp +++ /dev/null @@ -1,187 +0,0 @@ -#include "Backend.h" - -namespace TensorFrost { - -BackendType current_backend = BackendType::NotInitialized; -CodeGenLang current_kernel_lang = CodeGenLang::CPP; -CodeGenLang current_main_lang = CodeGenLang::CPP; -bool strip_debug_names = false; - -void InitializeBackend(BackendType backendType, const string& compilerOptions, CodeGenLang kernelType) { - if (current_backend != BackendType::NotInitialized) { - cout << "Warning: Backend already initialized, stopping current backend\n" << endl; - - switch (current_backend) { - case BackendType::CPU: - break; - case BackendType::Vulkan: - break; - case BackendType::OpenGL: - StopOpenGL(); - break; - default: - throw std::runtime_error("Backend not implemented"); - } - } - - if (!compilerOptions.empty()) { - kernelCompileOptions = compilerOptions; - } else if(backendType != BackendType::CPU) { - kernelCompileOptions = ""; //no need for cpu optimizations on other backends - } else { -#ifdef _WIN32 - kernelCompileOptions = "/O2 /fp:fast /openmp"; -#else - kernelCompileOptions = "-O3 -ffast-math -fopenmp"; -#endif - } - -#ifdef _DEBUG -#ifdef _WIN32 - kernelCompileOptions = "/Zi"; -#else - kernelCompileOptions = "-g"; -#endif -#endif - - current_backend = backendType; - - switch (backendType) { - case BackendType::CPU: - case BackendType::CodeGen: - current_kernel_lang = CodeGenLang::CPP; - global_memory_manager = new CpuMemoryManager(); - global_kernel_manager = new CpuKernelManager(); - break; - case BackendType::Vulkan: - throw std::runtime_error("Vulkan backend not implemented yet"); - current_kernel_lang = CodeGenLang::GLSL; - break; - case BackendType::OpenGL: - StartOpenGL(); - current_kernel_lang = CodeGenLang::GLSL; - global_memory_manager = new OpenGLMemoryManager(); - global_kernel_manager = new OpenGLKernelManager(); - break; - default: - throw std::runtime_error("Backend not implemented"); - } - - if (kernelType != CodeGenLang::None) { - current_kernel_lang = kernelType; - } -} - -void CompileKernels(Program* program) { - auto start_time = chrono::high_resolution_clock::now(); - for(auto& kernel : program->kernels_) { - switch (current_backend) { - case BackendType::CPU: - //already in the host program - break; - case BackendType::Vulkan: - throw std::runtime_error("Vulkan backend not implemented yet"); - case BackendType::OpenGL: - ((OpenGLKernelManager*)global_kernel_manager)->CompileKernel(&kernel); - break; - default: - throw std::runtime_error("Backend not implemented"); - } - } - auto end_time = chrono::high_resolution_clock::now(); - float milliseconds = chrono::duration(end_time - start_time).count(); - program->shader_compile_time = milliseconds; -} - -TFTensor Allocator(const char* name, const size_t* a, size_t dim, TFDataFormat format, void* data) { - try { - vector shape(a, a + dim); - return *global_memory_manager->AllocateTensor(shape, format, name); - } catch (const std::exception& e) { - size_t size = 1; - for (size_t i = 0; i < dim; i++) { - size *= a[i]; - } - throw std::runtime_error("Error allocating tensor " + string(name) + ": " + e.what() + ", requested size: " + to_string(size)); - } -} - -void Deallocator(TFTensor a, void* data) { - global_memory_manager->DeallocateTensor(a); -} - -uint Readback(TFTensor a, size_t index, void* data) { - return global_memory_manager->ReadbackValue(&a, index); -} - -void Writeback(TFTensor a, size_t index, uint32_t value, void* data) { - global_memory_manager->WritebackValue(&a, index, value); -} - -void Dispatch(TFDispatchInfo info, void* data) { - global_kernel_manager->DispatchKernel(info); -} - -void Region(const char* name, bool begin, void* data) { - if (current_backend == BackendType::OpenGL) { - if (begin) { - StartDebugRegion(name); - } else { - EndDebugRegion(); - } - } -} - -//#define PROFILE_EXECUTION - -vector ExecuteProgram( - Program* program, vector inputs) { - - if (current_backend == BackendType::CodeGen) { - throw std::runtime_error("Cannot execute program with code generation backend"); - } - - int memory_input_count = (int)program->ir_->input_memory_map.size(); - - if (memory_input_count != inputs.size()) { - throw std::runtime_error( - "Invalid number of inputs for TensorProgram. Expected " + - to_string(memory_input_count) + ", got " + to_string(inputs.size())); - } - - vector input_tensors; - for (int i = 0; i < memory_input_count; i++) { - input_tensors.push_back(*inputs[i]); - } - - unordered_map output_memory_map = program->ir_->output_memory_map; - int output_count = (int)output_memory_map.size(); - - TFTensor* in = input_tensors.data(); - TFTensor* out = new TFTensor[output_count]; - -#ifdef PROFILE_EXECUTION - auto start = chrono::high_resolution_clock::now(); -#endif - try { - program->execute_callback(in, out, {Allocator, Deallocator, Readback, Writeback, Dispatch, Region, nullptr}); - } catch (const std::exception& e) { - throw std::runtime_error("Error executing program " + program->program_name + ": " + e.what()); - } - -#ifdef PROFILE_EXECUTION - Finish(); - auto end = chrono::high_resolution_clock::now(); - float milliseconds = chrono::duration(end - start).count(); - program->last_execution_time = milliseconds; -#endif - - vector outputs = vector(output_count); - for (int i = 0; i < output_count; i++) { - outputs[i] = &out[i]; - } - - return outputs; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backend.h b/TensorFrost/Backend/Backend.h deleted file mode 100644 index dd4151f9..00000000 --- a/TensorFrost/Backend/Backend.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "Backends/CPU/CPU.h" -#include "Backends/OpenGL/OpenGL.h" -#include "CodeGen/Generators.h" -#include "KernelManager.h" -#include "TensorMemory.h" -#include "RenderDoc.h" - -namespace TensorFrost { - -using namespace std; - -enum class BackendType { - CPU, - Vulkan, - OpenGL, - CodeGen, - NotInitialized -}; - -enum class CodeGenLang { - CPP, - HLSL, - GLSL, - None, -}; - -extern BackendType current_backend; -extern CodeGenLang current_kernel_lang; -extern CodeGenLang current_main_lang; -extern bool strip_debug_names; - -vector ExecuteProgram( - Program* program, vector inputs); - -void InitializeBackend(BackendType backendType, const string& compilerPath, CodeGenLang kernelType); - -void CompileKernels(Program* program); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/CPU/CPU.h b/TensorFrost/Backend/Backends/CPU/CPU.h deleted file mode 100644 index c166ba59..00000000 --- a/TensorFrost/Backend/Backends/CPU/CPU.h +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "KernelCompiler.h" -#include "Memory.h" -#include "KernelManager.h" \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/CPU/KernelCompiler.cpp b/TensorFrost/Backend/Backends/CPU/KernelCompiler.cpp deleted file mode 100644 index 6a220b0f..00000000 --- a/TensorFrost/Backend/Backends/CPU/KernelCompiler.cpp +++ /dev/null @@ -1,222 +0,0 @@ -#include "KernelCompiler.h" - -#include - -namespace TensorFrost { - -std::string kernelCompileOptions; - -//TODO: Add support for MacOS -#ifndef __APPLE__ -bool RunCompiler(char* tempPath, char* dllName, const char* sourcePath) { - std::basic_stringstream ss; - std::string output; - -#if defined(_WIN32) - ss << "powershell -command \"$VisualStudioPath = & \\\"${Env:ProgramFiles(x86)}\\Microsoft Visual Studio\\Installer\\vswhere.exe\\\" -latest -products * -property installationPath; & cmd.exe /C \\\"\"\\\"\\\"$VisualStudioPath\\VC\\Auxiliary\\Build\\vcvarsall.bat\\\"\\\" x64 && cl " - << kernelCompileOptions << " /LD " << tempPath - << sourcePath << " /Fe:" << dllName - << "\"\"\\\"\""; -#else - ss << "g++ " << kernelCompileOptions << " -shared -fPIC " << tempPath - << sourcePath << " -o " << dllName; -#endif - - std::basic_string command = ss.str(); - - // Run the compiler -#if defined(_WIN32) - SECURITY_ATTRIBUTES sa; - sa.nLength = sizeof(SECURITY_ATTRIBUTES); - sa.bInheritHandle = TRUE; - sa.lpSecurityDescriptor = NULL; - - HANDLE hReadPipe, hWritePipe; - if (!CreatePipe(&hReadPipe, &hWritePipe, &sa, 0)) { - throw std::runtime_error("Failed to create pipe"); - } - - STARTUPINFO si; - PROCESS_INFORMATION pi; - ZeroMemory(&si, sizeof(si)); - si.cb = sizeof(si); - si.hStdError = hWritePipe; - si.hStdOutput = hWritePipe; - si.dwFlags |= STARTF_USESTDHANDLES; - - if (!CreateProcess(nullptr, command.data(), nullptr, nullptr, TRUE, 0, nullptr, nullptr, &si, &pi)) { - throw std::runtime_error(std::string("Steps error: cannot create compiler process. Command line: ") + command.data() + "\n"); - } - - CloseHandle(hWritePipe); - - char buffer[4096]; - DWORD bytesRead; - while (ReadFile(hReadPipe, buffer, sizeof(buffer), &bytesRead, NULL) && bytesRead != 0) { - output.append(buffer, bytesRead); - } - - WaitForSingleObject(pi.hProcess, INFINITE); - - DWORD exit_code; - GetExitCodeProcess(pi.hProcess, &exit_code); - if (exit_code != 0) { - throw std::runtime_error( - "Steps error: compiler exited with non-zero exit code (Error " - "code: " + std::to_string(exit_code) + ")\nCompiler output:\n" + output); - } - - CloseHandle(pi.hProcess); - CloseHandle(pi.hThread); - CloseHandle(hReadPipe); -#else - FILE* pipe = popen(command.c_str(), "r"); - if (!pipe) { - throw std::runtime_error("popen() failed!"); - } - - char buffer[128]; - while (fgets(buffer, sizeof(buffer), pipe) != nullptr) { - output += buffer; - } - - int status = pclose(pipe); - if (status != 0) { - throw std::runtime_error( - "Steps error: compiler exited with non-zero exit code (Error " - "code: " + std::to_string(status) + ")\nCompiler output:\n" + output); - } -#endif - - return true; -} - -void CompileKernelLibrary(const string& sourceCode, char* tempPath, - char* dllName, size_t program_id) { - // Append a file name to the tempPath - std::string source_name = "generated_lib_" + std::to_string(program_id) + ".cpp"; - std::basic_stringstream ss; - ss << tempPath << source_name; - std::basic_string full_file_path = ss.str(); - - const std::string& file_path(full_file_path); - - cout << "Source path: " << file_path << endl; - - // Write the generated source code to a file - std::ofstream out_file(file_path); - if (!out_file) { - throw std::runtime_error( - "Steps error: cannot open file for writing generated source code"); - } - out_file << sourceCode; - out_file.close(); - - RunCompiler(tempPath, dllName, source_name.c_str()); -} -#endif - -void CompileAndLoadKernelModule(Program* program, size_t program_id) { -#ifndef __APPLE__ -#if defined(_WIN32) - char temp_path[MAX_PATH]; - DWORD path_length = GetTempPath(MAX_PATH, temp_path); - - if (path_length == 0) { - throw std::runtime_error("Steps error: cannot get temp path"); - } - - // Create a temporary library name - char temp_file_name[MAX_PATH]; - if (!GetTempFileName(temp_path, TEXT("lib"), 0, temp_file_name)) { - throw std::runtime_error("Steps error: cannot create temp file"); - } -#else - char temp_path[] = "/tmp/"; - char filename_template[] = "/tmp/tensorfrost_XXXXXX"; - char* temp_file_name = mktemp(filename_template); - if (!temp_file_name) { - throw std::runtime_error("Steps error: cannot create temp file"); - } -#endif - - cout << "Temp file: " << temp_file_name << endl; - - // Compile the library - CompileKernelLibrary(program->generated_code_, temp_path, temp_file_name, program_id); - - // Load the library - #if defined(_WIN32) - HMODULE lib_handle = LoadLibrary(temp_file_name); - if (!lib_handle) { - throw std::runtime_error("Steps error: cannot load generated library"); - } - #else - void* lib_handle = dlopen(temp_file_name, RTLD_LAZY); - if (!lib_handle) { - throw std::runtime_error("Steps error: cannot load generated library"); - } - #endif - - // Create lambda function to free the library - program->unload_callback = [lib_handle]() { - #if defined(_WIN32) - if (!FreeLibrary(lib_handle)) { - std::cerr << "Cannot free library: " << GetLastError() << '\n'; - } - #else - if (dlclose(lib_handle)) { - std::cerr << "Cannot free library: " << dlerror() << '\n'; - } - #endif - }; - - // Load the main function - #if defined(_WIN32) - auto main_callback = reinterpret_cast( - GetProcAddress(lib_handle, "main")); - #else - auto main_callback = reinterpret_cast( - dlsym(lib_handle, "main")); - #endif - - if (!main_callback) { - throw std::runtime_error("Steps error: cannot load main function"); - } - - // Set the execute callback - program->execute_callback = *main_callback; - - // load cpu kernel functions - if (current_backend == BackendType::CPU) - { - for (auto& kernel : program->kernels_) { - #if defined(_WIN32) - auto kernel_callback = reinterpret_cast( - GetProcAddress(lib_handle, kernel.kernel_name_.c_str())); - #else - auto kernel_callback = reinterpret_cast( - dlsym(lib_handle, kernel.kernel_name_.c_str())); - #endif - - if (!kernel_callback) { - throw std::runtime_error("Steps error: cannot load kernel function"); - } - - ((CpuKernelManager*)global_kernel_manager) - ->AddKernelFunction(&kernel, kernel_callback); - -#ifndef NDEBUG - cout << "Loaded kernel: " << kernel.kernel_name_ << endl; -#endif - } - } - - cout << "Successfully compiled and loaded kernel library." << endl; - -#else - throw std::runtime_error("Steps error: cannot compile and load kernel module on macOS"); -#endif -} - -} // namespace TensorFrost diff --git a/TensorFrost/Backend/Backends/CPU/KernelCompiler.h b/TensorFrost/Backend/Backends/CPU/KernelCompiler.h deleted file mode 100644 index 3f397f99..00000000 --- a/TensorFrost/Backend/Backends/CPU/KernelCompiler.h +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#ifndef NOMINMAX -#define NOMINMAX -#endif - -#if defined(_WIN32) -#include -#else -#include -#include -#include -#endif - -#include "Backend/Backends/CPU/Memory.h" -#include "Backend/CodeGen/Generators.h" -#include "Backend/KernelManager.h" -#include "Backend/TensorMemory.h" -#include "Compiler/KernelGen.h" - -namespace TensorFrost { - -using namespace std; - -extern std::string kernelCompileOptions; - -void CompileAndLoadKernelModule(Program* program, size_t program_id); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/CPU/KernelManager.h b/TensorFrost/Backend/Backends/CPU/KernelManager.h deleted file mode 100644 index eb03b1d4..00000000 --- a/TensorFrost/Backend/Backends/CPU/KernelManager.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../KernelManager.h" - -namespace TensorFrost { - -class CpuKernelManager : public KernelManager { - unordered_map kernel_functions; - public: - - void AddKernelFunction(Kernel* kernel, cpu_dispatch_func* func) { - kernel_functions[kernel->kernel_id_] = func; - } - - cpu_dispatch_func* GetKernel(size_t id) { - return kernel_functions[id]; - } - - void DispatchKernel(TFDispatchInfo info) override - { - cpu_dispatch_func* func = kernel_functions[info.kernel_id]; - //get memory pointers - uint32_t** memory = new uint32_t*[info.read_write_count]; - for (size_t i = 0; i < info.read_write_count; i++) { - memory[i] = ((TFCPUBuffer*)info.read_write_tensors[i].buffer)->GetNative(); - } - func(info.variables, memory, (uint)info.work_group_count); - delete[] memory; - } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/CPU/Memory.h b/TensorFrost/Backend/Backends/CPU/Memory.h deleted file mode 100644 index cbc2375a..00000000 --- a/TensorFrost/Backend/Backends/CPU/Memory.h +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../TensorMemory.h" - -namespace TensorFrost { - -using namespace std; - -class TFCPUBuffer: public TFBufferTemplate { -public: - uint32_t* data; - - TFCPUBuffer(size_t size): TFBufferTemplate(size) { - data = new uint32_t[size]; - } - - void UpdateName(const char* new_name) override { - if(new_name != nullptr) { - name = new_name; - } - } - - void SetDataAtOffset(size_t offset, const vector& data) override { - memcpy(this->data + offset, data.data(), data.size() * sizeof(uint32_t)); - } - - void GetDataAtOffset(size_t offset, size_t size, uint32_t* data) override { - memcpy(data, this->data + offset, size * sizeof(uint32_t)); - } - - uint32_t* GetNative() const { - return data; - } - - ~TFCPUBuffer() { - delete[] data; - } -}; - -class CpuMemoryManager : public TensorMemoryManager { - public: - TFBuffer* CreateBuffer(size_t size) override { - return new TFCPUBuffer(size); - } - - void DeleteBuffer(TFBuffer* buffer) override { - delete (TFCPUBuffer*)buffer; - } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/OpenGL/KernelManager.h b/TensorFrost/Backend/Backends/OpenGL/KernelManager.h deleted file mode 100644 index a9d5929d..00000000 --- a/TensorFrost/Backend/Backends/OpenGL/KernelManager.h +++ /dev/null @@ -1,183 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../KernelManager.h" - -namespace TensorFrost { - -class OpenGLKernelManager : public KernelManager { - unordered_map kernel_map; - const int WORK_GROUP_SIZE = 256; - GLuint ubo; - public: - OpenGLKernelManager() { - glGenBuffers(1, &ubo); - //allocate sizeof(uint32_t) * 32 bytes - glBindBuffer(GL_UNIFORM_BUFFER, ubo); - glBufferData(GL_UNIFORM_BUFFER, sizeof(uint32_t) * 32, nullptr, GL_DYNAMIC_DRAW); - glBindBuffer(GL_UNIFORM_BUFFER, 0); - } - - void UpdateUBO(const uint32_t* data, size_t size) { - glBindBuffer(GL_UNIFORM_BUFFER, ubo); - glBufferSubData(GL_UNIFORM_BUFFER, 0, sizeof(uint32_t) * size, data); - glBindBuffer(GL_UNIFORM_BUFFER, 0); - } - - GLuint createComputeShader(const std::string& source) { - GLuint shader = glCreateShader(GL_COMPUTE_SHADER); - const char* src = source.c_str(); - glShaderSource(shader, 1, &src, nullptr); - glCompileShader(shader); - - // Check for compilation errors - GLint success; - glGetShaderiv(shader, GL_COMPILE_STATUS, &success); - if (!success) { - GLchar infoLog[512]; - glGetShaderInfoLog(shader, 512, nullptr, infoLog); - throw std::runtime_error("TensorFrost: Error compiling shader: " + source + "\n" + std::string(infoLog)); - } - - return shader; - } - - GLuint createShaderProgram(const std::string& computeShaderSource) { - GLuint computeShader = createComputeShader(computeShaderSource); - GLuint program = glCreateProgram(); - glAttachShader(program, computeShader); - glLinkProgram(program); - - // Check for linking errors - GLint success; - glGetProgramiv(program, GL_LINK_STATUS, &success); - if (!success) { - GLchar infoLog[512]; - glGetProgramInfoLog(program, 512, nullptr, infoLog); - throw std::runtime_error("Error linking program: " + std::string(infoLog)); - } - - glDeleteShader(computeShader); - return program; - } - - void CompileKernel(Kernel* kernel) - { - #ifndef NDEBUG //print out source if debug is enabled - cout << "Compiling kernel \n" << kernel->full_generated_code_ << endl; - #endif - try { - GLuint program = createShaderProgram(kernel->full_generated_code_); - kernel_map[kernel->kernel_id_] = program; - } catch (const std::exception& e) { - string error_message = "Error compiling kernel " + to_string(kernel->kernel_id_) + "\n" + e.what(); - error_message = error_message + "\n" + kernel->full_generated_code_; - throw std::runtime_error(error_message); - } - } - - //Get uniform location - GLint getUniformLocation(GLuint program, const std::string& name) { - GLint location = glGetUniformLocation(program, name.c_str()); - if (location == -1) { - throw std::runtime_error("OpenGL error: uniform " + name + " not found"); - } - return location; - } - - //Get attribute location - GLint getAttribLocation(GLuint program, const std::string& name) { - GLint location = glGetAttribLocation(program, name.c_str()); - if (location == -1) { - throw std::runtime_error("OpenGL error: attribute " + name + " not found"); - } - return location; - } - - void DispatchKernel(TFDispatchInfo info) override - { - GLuint program = kernel_map[info.kernel_id]; - Kernel* kernel = GetKernel(info.kernel_id); - glUseProgram(program); - - #ifndef NDEBUG - // validate the program - glValidateProgram(program); - GLint success; - glGetProgramiv(program, GL_VALIDATE_STATUS, &success); - if (!success) { - GLchar infoLog[512]; - glGetProgramInfoLog(program, 512, nullptr, infoLog); - throw std::runtime_error("OpenGL error: program validation failed: " + std::string(infoLog)); - } - #endif - - // Set uniforms - if (info.read_write_count == 0) throw std::runtime_error("No tensors provided to kernel"); - - //bind all memory buffers - for (size_t i = 0; i < info.read_write_count; i++) { - GLuint buffer = ((TFOpenGLBuffer*)info.read_write_tensors[i].buffer)->GetNative(); - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, (GLuint)i, buffer); - } - - if (info.variable_count > 0) - { - UpdateUBO(info.variables, info.variable_count); - } - - // Bind the UBO - glBindBufferBase(GL_UNIFORM_BUFFER, 0, ubo); - - // Dispatch the kernel - glDispatchCompute((GLuint)info.work_group_count, 1, 1); - - // Wait for the kernel to finish - glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT); - - // Unbind the memory buffers - for (size_t i = 0; i < info.read_write_count; i++) { - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, (GLuint)i, 0); - } - - // Unbind the program - glUseProgram(0); - - // Check for errors - GLenum error = glGetError(); - if (error != GL_NO_ERROR) { - throw std::runtime_error("OpenGL error: " + std::to_string(error)); - } - } - - void FreeKernel(int kernel_id) - { - GLuint program = kernel_map[kernel_id]; - glDeleteProgram(program); - kernel_map.erase(kernel_id); - } - - void FreeAllKernels() - { - for (auto& kernel : kernel_map) { - glDeleteProgram(kernel.second); - } - kernel_map.clear(); - } - - ~OpenGLKernelManager() - { - FreeAllKernels(); - glDeleteBuffers(1, &ubo); - } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/OpenGL/Memory.h b/TensorFrost/Backend/Backends/OpenGL/Memory.h deleted file mode 100644 index 07103023..00000000 --- a/TensorFrost/Backend/Backends/OpenGL/Memory.h +++ /dev/null @@ -1,102 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../TensorMemory.h" - -namespace TensorFrost { - -class TFOpenGLBuffer: public TFBufferTemplate { - GLuint buffer; - - const size_t max_cache_size = 16384; - uint32_t* cached_data = nullptr; - - public: - TFOpenGLBuffer(size_t size): TFBufferTemplate(size) { - GLint maxsize; - glGetIntegerv(GL_MAX_SHADER_STORAGE_BLOCK_SIZE, &maxsize); - - if (size * sizeof(uint32_t) > maxsize) { - throw std::runtime_error("SSBO memory size exceeded, max size is " + std::to_string(maxsize)); - } - - glCreateBuffers(1, &buffer); - glNamedBufferStorage(buffer, size * sizeof(uint32_t), nullptr, GL_DYNAMIC_STORAGE_BIT); - } - - void UpdateName(const char* new_name) override { - if(new_name != nullptr && strcmp(new_name, "") != 0) { - name = new_name; - glObjectLabel(GL_BUFFER, buffer, -1, name); - } - } - - void UpdateCache(size_t data_offset, size_t data_size, const uint32_t* data) { - return; - if(data_offset == 0 && data_size == used_size && data_size <= max_cache_size) { - if(cached_data == nullptr) { - cached_data = new uint32_t[data_size]; - } - memcpy(cached_data, data, data_size * sizeof(uint32_t)); - up_to_date = true; - } - } - - void SetDataAtOffset(size_t offset, const vector& data) override { - glBindBuffer(GL_SHADER_STORAGE_BUFFER, buffer); - glBufferSubData(GL_SHADER_STORAGE_BUFFER, offset * sizeof(uint32_t), - data.size() * sizeof(uint32_t), data.data()); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - - UpdateCache(offset, data.size(), data.data()); - } - - void GetDataAtOffset(size_t offset, size_t size, uint32_t* data) override { - if(up_to_date) { - memcpy(data, cached_data + offset, size * sizeof(uint32_t)); - return; - } - - glBindBuffer(GL_SHADER_STORAGE_BUFFER, buffer); - glGetBufferSubData(GL_SHADER_STORAGE_BUFFER, offset * sizeof(uint32_t), - size * sizeof(uint32_t), data); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - - UpdateCache(offset, size, data); - } - - GLuint GetNative() const { - return buffer; - } - - ~TFOpenGLBuffer() { - glDeleteBuffers(1, &buffer); - if(cached_data != nullptr) { - delete[] cached_data; - } - } -}; - -class OpenGLMemoryManager : public TensorMemoryManager { - public: - OpenGLMemoryManager() {} - - TFBuffer* CreateBuffer(size_t size) override { - return new TFOpenGLBuffer(size); - } - - void DeleteBuffer(TFBuffer* buffer) override { - delete (TFOpenGLBuffer*)buffer; - } -}; - - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/OpenGL/OpenGL.cpp b/TensorFrost/Backend/Backends/OpenGL/OpenGL.cpp deleted file mode 100644 index 9478b3cd..00000000 --- a/TensorFrost/Backend/Backends/OpenGL/OpenGL.cpp +++ /dev/null @@ -1,384 +0,0 @@ -#include "OpenGL.h" - -namespace TensorFrost { - -string vertex_shader = R"( -#version 430 core -out vec2 texCoords; - -void main() { - const vec2 pos[6] = vec2[](vec2(-1.0, -1.0), vec2(1.0, -1.0), vec2(1.0, 1.0), - vec2(-1.0, -1.0), vec2(1.0, 1.0), vec2(-1.0, 1.0)); - gl_Position = vec4(pos[gl_VertexID], 0.0, 1.0); - texCoords = pos[gl_VertexID] * 0.5 + 0.5; -} -)"; - -string fragment_shader = R"( -#version 430 core -in vec2 texCoords; -out vec4 FragColor; - -layout(std430, binding = 0) buffer memory { - uint mem[]; -}; - -uniform int offset; -uniform int width; -uniform int height; - -void main() { - ivec2 pixel = ivec2(texCoords.x * width, texCoords.y * height); - int pixel_offset = pixel.y * width + pixel.x; - int cur_offset = pixel_offset * 3 + offset; - float r = uintBitsToFloat(mem[cur_offset]); - float g = uintBitsToFloat(mem[cur_offset + 1]); - float b = uintBitsToFloat(mem[cur_offset + 2]); - FragColor = vec4(r, g, b, 1.0); -} -)"; - -GLuint CreateShader(const string& source, GLenum type) { - GLuint shader = glCreateShader(type); - const char* src = source.c_str(); - glShaderSource(shader, 1, &src, nullptr); - glCompileShader(shader); - - int success; - glGetShaderiv(shader, GL_COMPILE_STATUS, &success); - if (!success) { - char infoLog[512]; - glGetShaderInfoLog(shader, 512, nullptr, infoLog); - throw std::runtime_error("Failed to compile shader: " + string(infoLog)); - } - - return shader; -} - -GLuint CreateProgram(const string& vertexSource, const string& fragmentSource) { - GLuint vertexShader = CreateShader(vertexSource, GL_VERTEX_SHADER); - GLuint fragmentShader = CreateShader(fragmentSource, GL_FRAGMENT_SHADER); - - GLuint program = glCreateProgram(); - glAttachShader(program, vertexShader); - glAttachShader(program, fragmentShader); - glLinkProgram(program); - - int success; - glGetProgramiv(program, GL_LINK_STATUS, &success); - - return program; -} - -GLuint quad_program = 0; -GLFWwindow* global_window = nullptr; -bool window_open = false; -ImGuiIO* io; - -void GLAPIENTRY DebugCallback(GLenum source, GLenum type, GLuint id, - GLenum severity, GLsizei length, - const GLchar* message, const void* userParam) { - // Output or log the debug message - std::cerr << "OpenGL Debug: source=" << source << ", type=" << type - << ", id=" << id << ", severity=" << severity << endl; - std::cerr << "Message: " << message << endl << endl; -} - - -// Window close callback function -void WindowCloseCallback(GLFWwindow* window) { - window_open = false; - - // Instead of closing, hide the window - glfwHideWindow(window); - - // Prevent the window from closing - glfwSetWindowShouldClose(window, GLFW_FALSE); -} - -void WindowSizeCallback(GLFWwindow* window, int width, int height) { - glViewport(0, 0, width, height); -} - -void ImguiNewFrame() { - if (global_window == nullptr) { - throw std::runtime_error("Window: OpenGL not initialized"); - } - - ImGui_ImplOpenGL3_NewFrame(); - ImGui_ImplGlfw_NewFrame(); - ImGui::NewFrame(); -} - -void ImguiRender() { - if (global_window == nullptr) { - throw std::runtime_error("Window: OpenGL not initialized"); - } - - ImGui::Render(); - ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData()); -} - -std::string last_error; - -void error_callback(int error, const char* description) { - last_error = std::string(description); -} - -//#define OPENGL_DEBUG - -void StartOpenGL() { - glfwSetErrorCallback(error_callback); - - if (!glfwInit()) { - throw std::runtime_error("Failed to initialize GLFW: " + last_error); - } - - // Make window invisible - glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE); - global_window = glfwCreateWindow(800, 600, "TensorFrost", nullptr, nullptr); - - if (!global_window) { - int code = glfwGetError(nullptr); - glfwTerminate(); - throw std::runtime_error("Failed to create window (error " + std::to_string(code) + "): " + last_error); - } - - glfwMakeContextCurrent(global_window); - - - int version = gladLoadGL(glfwGetProcAddress); - if (version == 0) { - throw std::runtime_error("Failed to load OpenGL"); - } - - // Successfully loaded OpenGL - printf("Loaded OpenGL %d.%d\n", GLAD_VERSION_MAJOR(version), - GLAD_VERSION_MINOR(version)); - - // Print the renderer string, which usually contains the GPU's name - const GLubyte* renderer = glGetString(GL_RENDERER); - const GLubyte* vendor = glGetString(GL_VENDOR); - printf("Renderer: %s\nVendor: %s\n", renderer, vendor); - - // Print the max buffer size - GLint bufferSize; - glGetIntegerv(GL_MAX_SHADER_STORAGE_BLOCK_SIZE, &bufferSize); - printf("Max available buffer size, MB: %d\n", bufferSize / 1024 / 1024); - - // Print max SSBO bindings - GLint max_ssbo_bindings; - glGetIntegerv(GL_MAX_SHADER_STORAGE_BUFFER_BINDINGS, &max_ssbo_bindings); - printf("Maximum SSBO bindings supported: %d\n", max_ssbo_bindings); - - // Enable debug output - #if !defined(NDEBUG) || defined(OPENGL_DEBUG) - if (GLAD_GL_KHR_debug) { - glEnable(GL_DEBUG_OUTPUT); - glEnable(GL_DEBUG_OUTPUT_SYNCHRONOUS); - glDebugMessageCallback(DebugCallback, nullptr); - glDebugMessageControl(GL_DONT_CARE, GL_DONT_CARE, GL_DEBUG_SEVERITY_HIGH, 0, nullptr, - GL_TRUE); - } - #endif - - // Create the shader program - quad_program = CreateProgram(vertex_shader, fragment_shader); - - glfwSetWindowCloseCallback(global_window, WindowCloseCallback); - glfwSetWindowSizeCallback(global_window, WindowSizeCallback); - - IMGUI_CHECKVERSION(); - ImGui::CreateContext(); - io = &ImGui::GetIO(); (void)*io; - io->ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard; - io->ConfigFlags |= ImGuiConfigFlags_NavEnableGamepad; - - // Setup Dear ImGui style - ImGui::StyleColorsDark(); - - ImGui_ImplGlfw_InitForOpenGL(global_window, true); - ImGui_ImplOpenGL3_Init("#version 430"); - - ImguiNewFrame(); -} - -void StopOpenGL() { - if (global_window == nullptr) { - throw std::runtime_error("OpenGL not initialized"); - } - - glfwDestroyWindow(global_window); - glfwTerminate(); - - glDeleteProgram(quad_program); - global_window = nullptr; - - // Cleanup ImGui - ImGui_ImplOpenGL3_Shutdown(); - ImGui_ImplGlfw_Shutdown(); - ImGui::DestroyContext(); -} - -void ShowWindow(int width, int height, const char* title) { - window_open = true; - - if(global_window == nullptr) { - throw std::runtime_error("Window: OpenGL not initialized"); - } - - glfwSetWindowSize(global_window, width, height); - glfwSetWindowTitle(global_window, title); - glfwShowWindow(global_window); - - //reset viewport - glViewport(0, 0, width, height); -} - -void HideWindow() { - if(global_window == nullptr) { - throw std::runtime_error("Window: OpenGL not initialized"); - } - - window_open = false; - glfwHideWindow(global_window); -} - -void Finish() { - glFinish(); -} - -void RenderFrame(const TFTensor* tensor) { - if (global_window == nullptr) { - throw std::runtime_error("RenderFrame: OpenGL not initialized"); - } - - // Clear the screen - glClearColor(0.0f, 0.0f, 0.0f, 1.0f); - glClear(GL_COLOR_BUFFER_BIT); - - if(tensor != nullptr) - { - //check if tensor is 2d + 3 channels - if (tensor->dim != 3 || tensor->shape[2] != 3) { - throw std::runtime_error("Window: Render tensor must be of shape (height, width, 3)"); - } - - //check if tensor is float32 (TODO: use int8 instead) - if (tensor->format != TFTypeFloat32) { - throw std::runtime_error("Window: Render tensor must be of type float32"); - } - - GLuint ssbo = ((TFOpenGLBuffer*)tensor->buffer)->GetNative(); - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo); - - glUseProgram(quad_program); - - // Set the uniforms - int offset = 0; - int width = (int)tensor->shape[1]; - int height = (int)tensor->shape[0]; - glUniform1i(glGetUniformLocation(quad_program, "offset"), offset); - glUniform1i(glGetUniformLocation(quad_program, "width"), width); - glUniform1i(glGetUniformLocation(quad_program, "height"), height); - - // Draw the quad - glDrawArrays(GL_TRIANGLES, 0, 6); - - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, 0); - glUseProgram(0); - } - - ImguiRender(); - - // Swap the buffers - glfwSwapBuffers(global_window); - glfwPollEvents(); - - ImguiNewFrame(); -} - -bool WindowShouldClose() { return !window_open; } - -pair GetMousePosition() { - double x, y; - glfwGetCursorPos(global_window, &x, &y); - return {x, y}; -} - -pair GetWindowSize() { - int width, height; - glfwGetWindowSize(global_window, &width, &height); - return {width, height}; -} - -bool IsMouseButtonPressed(int button) { - //if pressed in imgui, return false - if (io->WantCaptureMouse) { - return false; - } - return glfwGetMouseButton(global_window, button) == GLFW_PRESS; -} - -bool IsKeyPressed(int key) { - return glfwGetKey(global_window, key) == GLFW_PRESS; -} - -void ImGuiBegin(std::string name) { - ImGui::Begin(name.c_str()); -} - -void ImGuiEnd() { - ImGui::End(); -} - -void ImGuiText(const std::string& text) { - ImGui::Text("%s", text.c_str()); -} - -void ImGuiSlider(std::string text, int* value, int min, int max) { - ImGui::SliderInt(text.c_str(), value, min, max); -} - -void ImGuiSlider(std::string text, float* value, float min, float max) { - ImGui::SliderFloat(text.c_str(), value, min, max, "%.5f"); -} - -bool ImGuiButton(std::string text) { - return ImGui::Button(text.c_str()); -} - -void ImGuiPlotLines(const char *label, const float *values, int values_count, int values_offset, - const char *overlay_text, float scale_min, float scale_max, ImVec2 graph_size, int stride) { - ImGui::PlotLines(label, values, values_count, values_offset, overlay_text, scale_min, scale_max, graph_size, stride); -} - -bool ImGuiCheckbox(std::string text, bool* value) { - return ImGui::Checkbox(text.c_str(), value); -} - -void ImGuiScaleAllSizes(float scale) { - ImGuiStyle& style = ImGui::GetStyle(); - style.ScaleAllSizes(scale); -} - -void ImGuiAddBackgroundText(const std::string &text, const ImVec2 &pos, const ImVec4 &color) { - ImDrawList* draw_list = ImGui::GetBackgroundDrawList(); - draw_list->AddText(pos, ImGui::GetColorU32(color), text.c_str()); -} - -void ImGuiColorPicker3(const std::string &text, float *color) { - ImGui::ColorPicker3(text.c_str(), color); -} - -void ImGuiColorPicker4(const std::string &text, float *color) { - ImGui::ColorPicker4(text.c_str(), color); -} - -void StartDebugRegion(const std::string& name) { - glPushDebugGroup(GL_DEBUG_SOURCE_APPLICATION, 0, (GLsizei)name.size(), name.c_str()); -} - -void EndDebugRegion() { glPopDebugGroup(); } - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/Backends/OpenGL/OpenGL.h b/TensorFrost/Backend/Backends/OpenGL/OpenGL.h deleted file mode 100644 index 8f7c0561..00000000 --- a/TensorFrost/Backend/Backends/OpenGL/OpenGL.h +++ /dev/null @@ -1,54 +0,0 @@ -#pragma once - -#include "imgui.h" -#include "imgui_impl_glfw.h" -#include "imgui_impl_opengl3.h" - -#include "glad/gl.h" -#include "GLFW/glfw3.h" - -#include "Memory.h" -#include "KernelManager.h" - -namespace TensorFrost { - -void StartOpenGL(); - -void StopOpenGL(); - -void ShowWindow(int width, int height, const char* title); -void HideWindow(); - -void Finish(); - -void RenderFrame(const TFTensor* tensor); - -bool WindowShouldClose(); - -pair GetMousePosition(); -pair GetWindowSize(); - -bool IsMouseButtonPressed(int button); -bool IsKeyPressed(int key); - -void ImGuiBegin(std::string name); -void ImGuiEnd(); - -void ImGuiText(const std::string& text); -void ImGuiSlider(std::string text, int* value, int min, int max); -void ImGuiSlider(std::string text, float* value, float min, float max); -bool ImGuiCheckbox(std::string text, bool* value); -bool ImGuiButton(std::string text); - -void ImGuiPlotLines(const char* label, const float* values, int values_count, int values_offset = 0, const char* overlay_text = NULL, float scale_min = FLT_MAX, float scale_max = FLT_MAX, ImVec2 graph_size = ImVec2(0, 0), int stride = sizeof(float)); - -void ImGuiScaleAllSizes(float scale); - -void ImGuiAddBackgroundText(const std::string& text, const ImVec2& pos, const ImVec4& color); -void ImGuiColorPicker3(const std::string& text, float* color); -void ImGuiColorPicker4(const std::string& text, float* color); - -void StartDebugRegion(const std::string& name); -void EndDebugRegion(); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CMakeLists.txt b/TensorFrost/Backend/CMakeLists.txt new file mode 100644 index 00000000..3289f00e --- /dev/null +++ b/TensorFrost/Backend/CMakeLists.txt @@ -0,0 +1,42 @@ +set(TF_BACKEND_INC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) +set(TF_BACKEND_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) + +file(GLOB_RECURSE TENSORFROST_BACKEND_SOURCE_LIST CONFIGURE_DEPENDS + ${TF_BACKEND_SRC_DIR}/*.cpp) + +file(GLOB_RECURSE TENSORFROST_BACKEND_HEADER_LIST CONFIGURE_DEPENDS + ${TF_BACKEND_INC_DIR}/*.h + ${TF_BACKEND_INC_DIR}/*.hpp) + +add_library(TensorFrostBackend STATIC + ${TENSORFROST_BACKEND_SOURCE_LIST} + ${TENSORFROST_BACKEND_HEADER_LIST}) + +target_include_directories(TensorFrostBackend + PUBLIC + ${TF_BACKEND_INC_DIR} + $ENV{VULKAN_SDK}/Include + ${CMAKE_SOURCE_DIR}/external/renderdoc + ${CMAKE_SOURCE_DIR}/external/imgui + ${CMAKE_SOURCE_DIR}/external/imgui/backends) + +target_link_libraries(TensorFrostBackend + PUBLIC + Vulkan::Vulkan + glfw + $<$:${SLANG_LIB_DEBUG}> + $<$>:${SLANG_LIB_RELEASE}>) + +target_compile_definitions(TensorFrostBackend PUBLIC VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) + +target_compile_features(TensorFrostBackend PUBLIC cxx_std_20) + +if (MSVC) + target_compile_options(TensorFrostBackend PRIVATE /wd4804 /wd4805 /wd4018) +endif() + +source_group(TREE ${TF_BACKEND_SRC_DIR} PREFIX "Source Files" + FILES ${TENSORFROST_BACKEND_SOURCE_LIST}) + +source_group(TREE ${TF_BACKEND_INC_DIR} PREFIX "Header Files" + FILES ${TENSORFROST_BACKEND_HEADER_LIST}) diff --git a/TensorFrost/Backend/CodeGen/Generators.cpp b/TensorFrost/Backend/CodeGen/Generators.cpp deleted file mode 100644 index 62bdae9f..00000000 --- a/TensorFrost/Backend/CodeGen/Generators.cpp +++ /dev/null @@ -1,556 +0,0 @@ -#ifndef TENSORFROST_BACKEND_CODEGEN_GENERATORS_CPP -#define TENSORFROST_BACKEND_CODEGEN_GENERATORS_CPP - -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { -using namespace std; - -void GenerateKernel(Program* program, Kernel* kernel) { - switch (current_kernel_lang) { - case CodeGenLang::CPP: - GenerateCPPKernel(program, kernel); - return; - case CodeGenLang::HLSL: - GenerateHLSLKernel(program, kernel); - return; - case CodeGenLang::GLSL: - GenerateGLSLKernel(program, kernel); - return; - default: - throw std::runtime_error("Code generation for this language is not implemented yet"); - } -} - -unordered_set forbidden_names = { - "unsigned", "input", "output", "max", "min", "exp", "sin", "cos", "if", "else", "while", "for", "switch", "case", "default", "break", - "this", "true", "false", "null", "new", "delete", "return", "continue", "goto", "try", "catch", "throw", - "const", "static", "extern", "inline", "virtual", "override", "final", "public", "protected", "private", "sample", - "texture", "sampler", "uniform", "varying", "attribute", "in", "out", "inout", "layout", "precision", "highp", - "mediump", "lowp", "noperspective", "flat", "smooth", "centroid", "patch", "sample", "subroutine", "common", - "partition", "active", "asm", "class", "union", "enum", "typedef", "template", "typename", "using", "namespace", - "friend", "sizeof", "alignof", "typeid", "dynamic_cast", "static_cast", "const_cast", "reinterpret_cast", "sizeof", - "alignof", "typeid", "noexcept", "throw", "auto", "register", "explicit", "mutable", "thread_local", - "constexpr", "decltype", "noexcept", "nullptr", "alignas", "and", "and_eq", "bitand", "bitor", "compl", "not", "not_eq", "or", - "sign", "xor", "xor_eq", "bool", "break", "case", "char" -}; - -bool IsForbiddenName(const string& name) { - return forbidden_names.contains(name); -} - -void GenerateNodeNames(const IR& ir) { - int var_index = 0; - int mem_index = 0; - int cluster_index = 0; - Node* curent_cluster = nullptr; - map name_count; - for (auto node = ir.begin(); !node.end(); node.next()) { - if (strip_debug_names && !node->debug_name.empty()) { - static int count = 0; - node->debug_name = "id_" + std::to_string(count++); - } - if (node->parent != curent_cluster) { - cluster_index++; - var_index = 0; - } - string debug = node->debug_name; - if (!debug.empty()) { - // check if the name is already used - if (name_count.contains(debug)) { - name_count[debug]++; - debug = debug + "_" + to_string(name_count[debug]); - } - else { - name_count[debug] = 1; - } - if (IsForbiddenName(debug) ) { - debug = debug + "0"; - } - node->var_name = debug; - } - else - { - if (node->name == "memory") { - node->var_name = debug + "m" + to_string(mem_index); - mem_index++; - } else { - node->var_name = - debug + "v" + to_string(cluster_index) + "_" + to_string(var_index); - var_index++; - } - } - - curent_cluster = node->parent; - } -} - - -string GetBufferDeclarations(Kernel *kernel, function get_name) { - map memory_bindings = kernel->GetMemoryBindings(); - - vector buffer_declarations = vector(memory_bindings.size()); - for (auto& buffer : memory_bindings) { - Node* mem_node = buffer.first; - size_t binding = buffer.second; - string name = mem_node->var_name; - string type_name = "uint"; - buffer_declarations[binding] = get_name(name, type_name, binding); - } - - string final_source; - for (auto& decl : buffer_declarations) { - final_source += decl; - } - - return final_source; -} - -string GetGroupBufferDeclarations(Kernel *kernel, function get_shared_name) { - string final_source; - - // add group memory declarations - for (auto& mem : kernel->group_memory) { - string name = mem->var_name; - string type_name = type_names[mem->format.type]; - //TODO: add support for non 32 bit types - final_source += get_shared_name(name, type_name, mem->data[0]); - } - - return final_source; -} - -string ReadVariable(Node* node) { - if (node->name == "const") { - return to_string(node->data[0]); - } - if (node->name == "memory") { - return "mem[" + node->var_name + "]"; - } - return node->var_name; -} - -string GetNodeName(const Node* node, bool compact) { - string name = node->var_name; - if (compact) { - if (node->name == "const" && !node->flags.has(NodeProp::Modified)) { - name = node->GetTensor()->GetConstantString(); - } - } - else { - if (node->name == "const") { - name = node->var_name + "(" + node->GetTensor()->GetConstantString() + ")"; - } - } - if (name.empty()) { - name = node->name + "_" + to_string(node->debug_index); - } - return name; -} - -#ifdef HAS_FORMAT -string format_float(double x) { - std::string s = std::format("{}", x); - if (s.find('.') == std::string::npos && s.find('e') == std::string::npos) { - s += '.'; - } - return s + 'f'; -} -#else -string format_float(double value) { - std::ostringstream out; - - // Determine when to use scientific notation vs fixed - bool use_scientific = std::abs(value) < 1e-4 || std::abs(value) > 1e6; - //and if not a zero value - use_scientific = use_scientific && value != 0.0; - if (use_scientific) { - out << std::scientific; // Use scientific notation for very small or large - // numbers - } else { - out << std::fixed; // Use fixed notation for moderate values - } - - out << std::setprecision(7) << value; - - // Convert to string - std::string str = out.str(); - - // Remove trailing zeros and potentially unnecessary decimal point - size_t endpos = str.find_last_not_of('0'); - if (endpos != std::string::npos) { - str = str.substr(0, endpos + 1); - } - if (str.back() == '.') { - str.pop_back(); - } - - // remove all zeros before "e" - size_t epos = str.find('e'); - if (epos != std::string::npos) { - size_t startpos = str.find_last_not_of('0', epos - 1); - if (startpos != std::string::npos) { - str = str.substr(0, startpos + 1) + str.substr(epos); - } - } - - if (str.find('.') == string::npos && str.find('e') == string::npos) { - str += '.'; - } - - // add a zero digit after the decimal point if the next character is not a - // digit - size_t dotpos = str.find('.'); - if (dotpos != std::string::npos && !isdigit(str[dotpos + 1])) { - //add a zero digit after the decimal point - str.insert(dotpos + 1, "0"); - } - - return str + 'f'; -} -#endif - -inline string Tensor::GetConstantString() const { - if (node_->name == "const" || node_->name == "dim_id") { - switch (node_->format.type) { - case TFType::Float: - return format_float(AsFloat(node_->data[0])); - case TFType::Int: - return to_string(AsInt(node_->data[0])); - case TFType::Uint: - return to_string(node_->data[0]) + "u"; - case TFType::Bool: - return node_->data[0] == 0 ? "false" : "true"; - default: - throw std::runtime_error("Unsupported constant type"); - } - } else { - return ""; - } -} - -void CodeGenerator::GenerateKernelCode(Kernel* kernel_) { - kernel = kernel_; - variables = kernel->variables; - read_write_bindings = kernel->read_write_memory; - read_only_bindings = kernel->read_only_memory; - GenerateCode(kernel->root); -} - -void CodeGenerator::GenerateCode(const Node* root) { - int variable_index = 0; - int memory_index = 0; - int prev_depth = 0; - // Translate each operation into HLSL - for (auto node = NodeIterator(root); !node.end(); node->name == "kernel" ? node.forward() : node.next()) { - string name = node->var_name; - - int depth = node.depth() - 1; - if (depth != prev_depth) { - // add scope brackets - if (depth < prev_depth) { - for (int i = prev_depth - 1; i >= depth; i--) { - lines.push_back(new Line(i, "}")); - } - } else if (depth > prev_depth) { - for (int i = prev_depth; i < depth; i++) { - lines.push_back(new Line(i, "{")); - } - } - } - - Line* line = nullptr; - if (custom_generated_code_.contains(*node)) { - line = new Line(*node, "", custom_generated_code_[*node], ";", ""); - } else { - // get node arguments - line = GenerateLine(*node); - } - - if (line == nullptr) { - continue; - } - - line->indent = depth; - lines.push_back(line); - - for (auto additional: additional_lines) { - lines.push_back(new Line(depth, additional)); - } - additional_lines.clear(); - - prev_depth = depth; - } - - // add closing brackets - for (int i = prev_depth - 1; i >= 0; i--) { - lines.push_back(new Line(i, "}")); - } - - //remove lines - unordered_set remove_lines; - for (auto& line : lines) { - if (lines_to_remove.contains(line->node)) { - remove_lines.insert(line); - } - } - - for (auto& line : remove_lines) { - lines.erase(std::remove(lines.begin(), lines.end(), line), lines.end()); - } -} - - -string CodeGenerator::AssembleString() { - string code; - int indent = 0; - for (auto& line : lines) { - for (int i = 0; i < line->indent; i++) { - code += " "; - } - code += line->left; - code += line->expression; - code += line->right; - code += "\n"; - } - return code; -} - -Line* CodeGenerator::GenerateLine(Node* node) { - ArgumentManager& args = node->args; - if (kernel) RegenerateNodeName(node); - GenerateArgumentNames(args); - const Operation* op = node->op; - - string name = node->var_name; - - // get output type - TFDataFormat output_format = node->format; - //TODO: add support for non 32 bit types - - // generate line - string left = ""; - string expression = ""; - string right = ""; - bool needs_paranthesis = false; - - if (op->HasAllTypes(OpProp::Special)) { - int dims = args.Count(ArgType::Shape); - - string shape_arg = "{"; - - for (int j = 0; j < dims; j++) { - if (j != 0) { - shape_arg += ", "; - } - Node* shape_node = args.Get(ArgType::Shape, j); - - shape_arg += "(uint)" + args.Name(ArgType::Shape, dims - j - 1); - } - - shape_arg += "}"; - - if (op->name_ == "loop") { - left += GenerateLoop(&args, name); - } else if (op->name_ == "if") { - left += GenerateIf(&args); - } else if (op->name_ == "memory") { - // if input memory type then just take the input and store it in the - // output - if (node->flags.has(NodeProp::InputMemory)) { - left += "tf.check_tensor(" + node->var_name+ ", \"" + node->var_name + "\", " + shape_arg + ", " + DataFormatNames[output_format] + ")"; - right += ";"; - } - // if any other memory type - allocate it - else { - left += "TFTensor " + node->var_name + " = "; - expression += "tf.allocate(\"" + node->var_name + "\", " + shape_arg + ", " + DataFormatNames[output_format] + ")"; - right += ";"; - } - } else if (op->name_ == "deallocate") { - left = "tf.deallocate(" + args.Name(ArgType::Memory) + ")"; - right = ";"; - } else if (op->name_ == "input_shape") { - left = "int " + node->var_name + " = "; - expression = ir->input_memory_map[(int)node->flags.get(NodeProp::InputShapeMemory)]->var_name + ".shape[" + to_string((int)node->flags.get(NodeProp::InputShapeDim)) + "]"; - right = ";"; - } else if(op->HasAllTypes(OpProp::MemoryReuse)) { - left = "TFTensor " + node->var_name + " = "; - expression = "tf." + op->code_ + "(" + args.Name(ArgType::Memory) + ", \"" + node->var_name + "\", " + shape_arg + ", " + DataFormatNames[output_format] + ")"; - right = ";"; - } else if(op->HasAllTypes(OpProp::Debug)) { - left = "tf." + op->code_ + "(\"" + node->debug_name + "\""; - if (args.Has(ArgType::Input)) { - left += ", " + args.Name(ArgType::Input); - } - left += ")"; - right = ";"; - } else if(op->name_ == "local_memory") { - left = type_names[output_format.type] + " " + name + "[" + to_string(node->data[0]) + "]"; - right = ";"; - } else if(op->name_ == "group_memory") { - left = ""; - //just leave as comment, actual declaration is done outside of the main body of the kernel - right = "//" + type_names[output_format.type] + " " + name + "[" + to_string(node->data[0]) + "]"; - } - } else if (op->HasAllTypes(OpProp::MemoryOp)) { - string address; - - if (kernel) { - address = "0"; - // if has index (not a scalar) - if (args.Has(ArgType::Index)) { - address = args.Name(ArgType::Index); - } - - bool is_local = node->flags.has(NodeProp::LocalMemoryOp); - string memory_name = args.Name(ArgType::Memory) + (is_local ? "" : "_mem"); - string memory_expression = memory_name + "[" + address + "]"; - TFType memory_type = is_local ? node->format.type : Uint; - string memory_type_name = type_names[memory_type]; - - if (op->name_ == "load") { - string output_type_name = type_names[output_format.type]; - left += output_type_name + " " + name + " = "; - expression += - (output_format.type == memory_type) - ? memory_expression - : TypeReinterpret(output_type_name, memory_expression); - right += ";"; - } else if (op->name_ == "store") { - expression += memory_expression + " = "; - expression += - (output_format.type == memory_type) - ? args.Name(ArgType::Input) - : TypeReinterpret(memory_type_name, args.Name(ArgType::Input)); - right += ";"; - } else if (op->HasAllTypes(OpProp::Scatter)) { - if (output_format.type != None) { - left += type_names[output_format.type] + " " + name + " = "; - } - string output_type_name = type_names[output_format.type]; - string input_type_name = type_names[args.Format(ArgType::Input).type]; - expression += GenerateAtomicOp(op->name_, input_type_name, - output_type_name, address, - args.Name(ArgType::Input), name, memory_name); - right += ";"; - } - } else { - string tensor_name = args.Name(ArgType::Memory); - string address = "0"; - if (args.Has(ArgType::Index)) { - address = args.Name(ArgType::Index); - } - - if (op->name_ == "load") { - //do readback - string output_type_name = type_names[output_format.type]; - left += output_type_name + " " + name + " = "; - string memory_expression = GetName("tf.read") + "(" + tensor_name + ", " + address + ")"; - expression += (output_format.type == Uint) - ? memory_expression - : TypeReinterpret(output_type_name, memory_expression); - right += ";"; - } else if (op->name_ == "store") { - //do writeback - string memory_expression = GetName("tf.write") + "(" + tensor_name + ", " + address + ", "; - expression += memory_expression + args.Name(ArgType::Input) + ")"; - right += ";"; - } else if (op->HasAllTypes(OpProp::Scatter)) { - throw std::runtime_error("Scatter operation not supported in non-kernel mode"); - } - } - - } else if (op->name_ == "set") { - left += args.Name(ArgType::Memory) + " = "; - expression += args.Name(ArgType::Input); - right += ";"; - } else { - if (output_format.type != None) { - left += type_names[output_format.type] + " " + name + " = "; - } - string line; - string code = op->code_; - switch (op->class_) { - case OpClass::Operator: - args.AddParenthesis(true); - if ((code == "&" || code == "|") && output_format.type == Bool) { - code = code + code; - } - line += args.Name(ArgType::Input, 0) + " " + code + " " + - args.Name(ArgType::Input, 1); - needs_paranthesis = true; - break; - case OpClass::UnaryOperator: - args.AddParenthesis(true); - line += op->code_ + args.Name(ArgType::Input, 0); - needs_paranthesis = true; - break; - case OpClass::Function: - line += GetName(op->code_) + "("; - for (int i = 0; i < args.Count(ArgType::Input); i++) { - if (i != 0) { - line += ", "; - } - line += args.Name(ArgType::Input, i); - } - line += ")"; - break; - case OpClass::Copy: - line += args.Name(ArgType::Input, 0); - needs_paranthesis = true; - break; - case OpClass::Keyword: - line += op->code_; - break; - case OpClass::DimensionIndex: - line += op->code_ + to_string(node->data[0]); - break; - case OpClass::Variable: - line += op->code_; - break; - case OpClass::TypeCast: - line += GenerateTypeCast(&args, DataTypeToString(node->format.type)); - break; - case OpClass::TypeReinterpret: - line += GenerateTypeReinterpret(&args, DataTypeToString(node->format.type)); - break; - case OpClass::Constant: - line += node->GetTensor()->GetConstantString(); - break; - case OpClass::TernaryOperator: - args.AddParenthesis(true); - line += args.Name(ArgType::Input, 0) + " ? " + - args.Name(ArgType::Input, 1) + " : " + - args.Name(ArgType::Input, 2); - needs_paranthesis = true; - break; - default: - throw std::runtime_error("Unknown operation class"); - break; - } - expression += line; - right += ";"; - } - - node_expression[node] = expression; - requires_paranthesis[node] = needs_paranthesis; - - return new Line(node, left, expression, right, name); -} - -string AddIndent(const string& input, const string& indent) { - stringstream ss(input); - string line; - string indentedText; - - while (getline(ss, line)) { - indentedText += indent + line + "\n"; - } - - return indentedText; -} - -} // namespace TensorFrost - -#endif \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Generators.h b/TensorFrost/Backend/CodeGen/Generators.h deleted file mode 100644 index 2de548ea..00000000 --- a/TensorFrost/Backend/CodeGen/Generators.h +++ /dev/null @@ -1,205 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include -#include - -#if __has_include() - #include - #define HAS_FORMAT 1 -#else - #define HAS_FORMAT 0 -#endif - -#include "Compiler/KernelGen.h" -#include "Tensor/Tensor.h" -#include "Backend/Backend.h" - -namespace TensorFrost { - -string GetNodeName(const Node* node, bool compact = false); -string ReadVariable(Node* node); -void GenerateNodeNames(const IR& ir); - -string GetBufferDeclarations(Kernel* kernel, function get_name); -string GetGroupBufferDeclarations(Kernel* kernel, function get_shared_name); -string GetCPPHeader(); -string GetCPPImplementation(); -string GetHLSLHeader(Kernel* kernel); -string GetGLSLHeader(Kernel* kernel); -void GenerateMain(Program* program, map& dispatch_code, int input_count, int output_count); -void GenerateKernel(Program* program, Kernel* kernel); -void GenerateCPPKernel(Program* program, Kernel* kernel); -void GenerateHLSLKernel(Program* program, Kernel* kernel); -void GenerateGLSLKernel(Program* program, Kernel* kernel); -void GenerateCode(Program* program); - -string GetNodeString(const Node* node, bool verbose = false); -string GetOperationListing(const IR&, bool compact = false, - map invalid = {}); - -bool IsForbiddenName(const string& name); - -using ArgumentNames = map; - -class Line { - public: - Node* node; - string left; - string expression; - string right; - string name; - int indent; - - Line(Node* node, string left, string expression, string right, string name, int indent = 0) - : left(left), right(right), name(name), indent(indent), expression(expression), node(node) {} - - Line(int indent, string expression) - : indent(indent), expression(expression), left(""), right(""), name(""), node(nullptr) {} -}; - -class CodeGenerator { -protected: - unordered_map name_map_; - set used_names_; - public: - list lines; - map custom_generated_code_; - - Kernel* kernel = nullptr; - IR* ir = nullptr; - - CodeGenerator(IR* ir) : ir(ir) {} - - void GenerateKernelCode(Kernel *kernel_); - void GenerateCode(const Node* root); - string AssembleString(); - -protected: - map read_write_bindings; - map read_only_bindings; - map variables; - map node_expression; - map requires_paranthesis; - unordered_set lines_to_remove; - vector additional_lines; - unordered_map name_count; - - virtual void GenerateArgumentNames(ArgumentManager& args) { - for (auto& arg : args.Inputs()) { - Node* node = arg.second; - ArgID id = arg.first; - string name = node->var_name; - bool need_parenthesis = false; - if (variables.contains(node)) { - name = GetName("var") + name; - } else { - string expr = node_expression[node]; - bool is_memory = node->op->HasAllTypes(OpProp::Memory); - bool is_static = node->op->HasAllTypes(OpProp::Static) || - node->op->HasAllTypes(OpProp::CantSubstitute); - bool is_constant = node->op->class_ == OpClass::Constant; - bool is_variable = node->op->class_ == OpClass::Variable; - if (is_constant && expr == "") { - expr = node->GetTensor()->GetConstantString(); - } - bool has_name = node->debug_name != ""; - bool has_single_output = (node->args.OutputCount() == 1) || is_constant || is_variable; - bool modified = node->flags.has(NodeProp::Modified); - bool short_enough = expr.size() < 100; - bool can_substitude = !has_name && has_single_output && !modified && short_enough && !is_static && !is_memory; - if (can_substitude) { - if (expr == "") { - throw std::runtime_error("Substitute expression is empty"); - } - name = expr; - need_parenthesis = requires_paranthesis[node]; - lines_to_remove.insert(node); - } - } - args.SetName(id, name, need_parenthesis); - } - } - - void RegenerateNodeName(Node* node) { - if(node->op->HasAllTypes(OpProp::LocalMemory)) { - used_names_.insert(node->var_name); - return; - } - string debug = node->debug_name; - if (debug.empty()) { - debug = "v" + node->name;//return; - } - if (IsForbiddenName(debug)) { - debug = debug + "0"; - } - // check if the name is already used - if (name_count.contains(debug)) { - name_count[debug]++; - debug = debug + "_" + to_string(name_count[debug]); - } else { - name_count[debug] = 1; - } - if (used_names_.contains(debug)) { - debug = debug + "_"; - } - node->var_name = debug; - } - - Line* GenerateLine(Node* node); - - virtual string GenerateLoop(ArgumentManager* args, const string& name) - { - string in1 = args->Name(ArgType::Input, 0), in2 = args->Name(ArgType::Input, 1), in3 = args->Name(ArgType::Input, 2); - return "for (int " + name + " = " + in1 + "; " + name + " < " + in2 + "; " + name + " += " + in3 + ")"; - } - - virtual string GenerateIf(ArgumentManager* args) - { - return "if (" + args->Name(ArgType::Input, 0) + ")"; - } - - virtual string TypeCast(string type_name, string input) - { - return "((" + type_name + ")(" + input + "))"; - } - - virtual string TypeReinterpret(string type_name, string input) { - return "as" + type_name + "(" + input + ")"; - } - - virtual string GenerateTypeCast(ArgumentManager* args, const string& type_name) - { - return TypeCast(type_name, args->Name(ArgType::Input, 0)); - } - - virtual string GenerateTypeReinterpret(ArgumentManager* args, const string& type_name) - { - return TypeReinterpret(type_name, args->Name(ArgType::Input, 0)); - } - - virtual string GenerateAtomicOp(const string& op, - const string& input_type_name, - const string& output_type_name, - const string& address, const string& input, const string& output, const string& memory_name) - { - return op + "((" + input_type_name + "*)" + memory_name + ", " + address + ", " + input + ")"; - } - - string GetName(const string& name) { - // Check if the function name is in the map - if (name_map_.find(name) != name_map_.end()) { - return name_map_[name]; - } - - // If not, return the original name - return name; - } -}; - -string AddIndent(const string& input, const string& indent); - - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Langs/CPP.cpp b/TensorFrost/Backend/CodeGen/Langs/CPP.cpp deleted file mode 100644 index a25b5f68..00000000 --- a/TensorFrost/Backend/CodeGen/Langs/CPP.cpp +++ /dev/null @@ -1,781 +0,0 @@ -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { -using namespace std; - - -class CPPGenerator : public CodeGenerator { -public: - CPPGenerator(IR* ir) : CodeGenerator(ir) { - name_map_ = { - {"sqrt", "sqrtf"}, - {"var", "var_"}, - }; - } -}; - -string GetCPPHeader() { - string header = R"( -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -typedef uint32_t uint; - -inline int min(int a, int b) -{ - return a < b ? a : b; -} - -inline int max(int a, int b) -{ - return a > b ? a : b; -} - -inline float min(float a, float b) -{ - return a < b ? a : b; -} - -inline float max(float a, float b) -{ - return a > b ? a : b; -} - -inline float asfloat(uint x) -{ - return *(float*)&x; -} - -inline uint asuint(float x) -{ - return *(uint*)&x; -} - -inline uint asuint(double x) -{ - return asuint((float)x); -} - -inline uint asuint(int x) -{ - return *(uint*)&x; -} - -inline uint asuint(uint x) -{ - return *(uint*)&x; -} - -inline int asint(uint x) -{ - return *(int*)&x; -} - -inline uint asuint(bool x) -{ - return *(uint*)&x; -} - -inline bool asbool(uint x) -{ - return *(bool*)&x; -} - -inline int clamp(int x, int a, int b) -{ - return min(max(x, a), b); -} - -inline float clamp(float x, float a, float b) -{ - return min(max(x, a), b); -} - -inline float lerp(float a, float b, float t) -{ - return a + (b - a) * t; -} - -inline float smoothstep(float a, float b, float t) -{ - t = clamp((t - a) / (b - a), 0.0f, 1.0f); - return t * t * (3.0f - 2.0f * t); -} - -inline float sign(float x) -{ - return x < 0.0f ? -1.0f : 1.0f; -} - -inline int sign(int x) -{ - return x < 0 ? -1 : 1; -} - -inline uint reversebits(uint x) -{ - x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1)); - x = (((x & 0xcccccccc) >> 2) | ((x & 0x33333333) << 2)); - x = (((x & 0xf0f0f0f0) >> 4) | ((x & 0x0f0f0f0f) << 4)); - x = (((x & 0xff00ff00) >> 8) | ((x & 0x00ff00ff) << 8)); - return ((x >> 16) | (x << 16)); -} - -inline int reversebits(int x) -{ - return (uint)reversebits((uint)x); -} - -inline void InterlockedAdd(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_add(value, std::memory_order_relaxed); -} - -inline void InterlockedAdd(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_add(value, std::memory_order_relaxed); -} - -inline void InterlockedAdd(float* memory, int address, float value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - float current = place->load(std::memory_order_relaxed); - float goal = current + value; - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = current + value; - } -} - -inline int InterlockedAdd_Prev(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - return place->fetch_add(value, std::memory_order_relaxed); -} - -inline uint InterlockedAdd_Prev(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - return place->fetch_add(value, std::memory_order_relaxed); -} - -inline float InterlockedAdd_Prev(float* memory, int address, float value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - float current = place->load(std::memory_order_relaxed); - float goal = current + value; - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = current + value; - } - return current; -} - -inline void InterlockedAnd(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_or(value, std::memory_order_relaxed); -} - -inline void InterlockedAnd(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_and(value, std::memory_order_relaxed); -} - -inline void InterlockedOr(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_or(value, std::memory_order_relaxed); -} - -inline void InterlockedOr(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_or(value, std::memory_order_relaxed); -} - -inline void InterlockedXor(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_xor(value, std::memory_order_relaxed); -} - -inline void InterlockedXor(uint* memory, int address, uint value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - place->fetch_xor(value, std::memory_order_relaxed); -} - -inline void InterlockedMin(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - int current = place->load(std::memory_order_relaxed); - int goal = min(current, value); - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = min(current, value); - } -} - -inline void InterlockedMin(float* memory, int address, float value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - float current = place->load(std::memory_order_relaxed); - float goal = min(current, value); - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = min(current, value); - } -} - -inline void InterlockedMax(int* memory, int address, int value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - int current = place->load(std::memory_order_relaxed); - int goal = max(current, value); - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = max(current, value); - } -} - -inline void InterlockedMax(float* memory, int address, float value) -{ - std::atomic* place = reinterpret_cast*>(&memory[address]); - float current = place->load(std::memory_order_relaxed); - float goal = max(current, value); - while (!place->compare_exchange_weak(current, goal, std::memory_order_release, std::memory_order_relaxed)) { - goal = max(current, value); - } -} - -inline void group_barrier() {} //NOOP, TODO: implement properly - -inline uint pcg(uint v) -{ - uint state = v * 747796405u + 2891336453u; - uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -inline float pcgf(uint v) -{ - return (float)pcg(v) / (float)0xffffffffu; -} - -extern "C" { - enum TFType { - Float, - Uint, - Int, - Bool, - None, - }; - - struct TFDataFormat { - TFType type; - size_t size; - - bool operator==(const TFDataFormat& other) const { - return type == other.type && size == other.size; - } - - bool operator!=(const TFDataFormat& other) const { - return !(*this == other); - } - - int GetHash() const { - return (int)type << 16 | (int)size; - } - - bool operator<(const TFDataFormat& other) const { - return GetHash() < other.GetHash(); - } - - bool operator>(const TFDataFormat& other) const { - return GetHash() > other.GetHash(); - } - }; - -#define TFTypeNone TFDataFormat{TFType::None, 0} -#define TFTypeBool32 TFDataFormat{TFType::Bool, 32} -#define TFTypeFloat32 TFDataFormat{TFType::Float, 32} -#define TFTypeInt32 TFDataFormat{TFType::Int, 32} -#define TFTypeUint32 TFDataFormat{TFType::Uint, 32} - - struct TFBuffer { - size_t size = 0; - size_t used_size = 0; - size_t time_since_used = 0; - bool up_to_date = false; - bool read_only = false; - const char* name = nullptr; - //add type descriptor (for special kinds of buffers) - }; - - struct TFTensor { - TFBuffer* buffer; - TFDataFormat format; - size_t dim; - const size_t* shape; - }; - - struct TFTensorList { - size_t count; - const TFTensor* tensors; - }; - - struct TFDispatchInfo { - size_t kernel_id; - size_t read_write_count; - const TFTensor* read_write_tensors; - size_t read_only_count; - const TFTensor* read_only_tensors; - size_t variable_count; - const uint32_t* variables; - size_t work_group_count; - }; - - typedef TFTensor alloc_func(const char*, const size_t*, size_t, TFDataFormat, void*); - typedef void dealloc_func(TFTensor, void*); - typedef uint readback_func(TFTensor, size_t, void*); - typedef void writeback_func(TFTensor, size_t, uint32_t, void*); - typedef void dispatch_func(TFDispatchInfo, void*); - typedef void region_func(const char*, bool, void*); - - struct TFRuntime { - alloc_func* alloc; - dealloc_func* dealloc; - readback_func* readback; - writeback_func* writeback; - dispatch_func* dispatch; - region_func* region; - void* custom_data; - }; -} - -class TFContext -{ -public: - TFRuntime runtime; - - TFContext(TFRuntime runtime); - size_t compute_size(const size_t* shape, size_t dim); - TFTensor allocate(std::string name, std::initializer_list shape, TFDataFormat type); - void deallocate(TFTensor tensor); - void check_tensor(TFTensor tensor, std::string name, std::initializer_list target_shape, TFDataFormat target_type); - TFTensor reshape(TFTensor tensor, std::string name, std::initializer_list shape, TFDataFormat type); - TFTensor assert_tensor(TFTensor tensor, std::string name, std::initializer_list target_shape, TFDataFormat target_type); - uint32_t read(TFTensor tensor, size_t index); - void write(TFTensor tensor, size_t index, uint32_t value); - void dispatch(size_t kernel_id, std::initializer_list read_write, std::initializer_list read_only, std::initializer_list var, std::initializer_list shape, std::initializer_list group); - void region_begin(std::string name); - void region_end(std::string name); - template - void print_value(std::string name, T value); - void assert_value(std::string name, bool condition); -}; -)"; - - return header; -} - -string GetCPPImplementation() { - string implementation = R"( - -std::unordered_map TFTypeNames = { - {TFType::Float, "Float"}, {TFType::Uint, "Uint"}, - {TFType::Int, "Int"}, {TFType::Bool, "Bool"}, - {TFType::None, "None"}, -}; - -TFContext::TFContext(TFRuntime runtime) : runtime(runtime) {} - -size_t TFContext::compute_size(const size_t* shape, size_t dim) { - size_t size = 1; - for (size_t i = 0; i < dim; i++) { - size *= shape[i]; - } - return size; -} - -TFTensor TFContext::allocate(std::string name, std::initializer_list shape, TFDataFormat type) -{ - const size_t* shape_arr = shape.begin(); - size_t dim = shape.size(); - size_t size = compute_size(shape_arr, dim); - - for (size_t i = 0; i < dim; i++) { - if(shape_arr[i] < 1) { - throw std::runtime_error("Invalid shape on dimension " + std::to_string(i) + " for " + name + ". Expected positive integer, got " + std::to_string(shape_arr[i])); - } - } - - return runtime.alloc(name.c_str(), shape_arr, dim, type, runtime.custom_data); -} - -void TFContext::deallocate(TFTensor tensor) -{ - runtime.dealloc(tensor, runtime.custom_data); -} - -void TFContext::check_tensor(TFTensor tensor, std::string name, std::initializer_list target_shape, TFDataFormat target_format) -{ - const size_t* shape_arr = tensor.shape; - const size_t* target_shape_arr = target_shape.begin(); - size_t target_dim = target_shape.size(); - - if (tensor.format != target_format) { - throw std::runtime_error("Invalid type for " + name + ". Expected " + TFTypeNames[target_format.type] + ", got " + TFTypeNames[tensor.format.type]); - } - - if (tensor.dim != target_dim) { - throw std::runtime_error("Invalid number of dimensions for " + name + ". Expected " + std::to_string(target_dim) + ", got " + std::to_string(tensor.dim)); - } - - for (size_t i = 0; i < tensor.dim; i++) { - if (shape_arr[i] != target_shape_arr[i] || target_shape_arr[i] < 1) { - throw std::runtime_error("Invalid shape for dimension " + std::to_string(i) + " in " + name + ". Expected " + std::to_string(target_shape_arr[i]) + ", got " + std::to_string(shape_arr[i])); - } - } -} - -TFTensor TFContext::reshape(TFTensor tensor, std::string name, std::initializer_list shape, TFDataFormat type) -{ - size_t* new_shape = new size_t[shape.size()]; - std::copy(shape.begin(), shape.end(), new_shape); - TFTensor new_tensor = {tensor.buffer, type, shape.size(), new_shape}; - - size_t old_size = compute_size(tensor.shape, tensor.dim); - size_t new_size = compute_size(new_tensor.shape, new_tensor.dim); - - if(old_size != new_size) { - throw std::runtime_error("Cannot reshape " + name + ", expected " + std::to_string(new_size) + " elements, while input has " + std::to_string(old_size)); - } - - return new_tensor; -} - -TFTensor TFContext::assert_tensor(TFTensor tensor, std::string name, std::initializer_list target_shape, TFDataFormat target_type) -{ - check_tensor(tensor, name, target_shape, target_type); - return tensor; -} - -uint TFContext::read(TFTensor tensor, size_t index) -{ - return runtime.readback(tensor, index, runtime.custom_data); -} - -void TFContext::write(TFTensor tensor, size_t index, uint32_t value) -{ - runtime.writeback(tensor, index, value, runtime.custom_data); -} - -void TFContext::dispatch(size_t kernel_id, std::initializer_list read_write, std::initializer_list read_only, std::initializer_list var, std::initializer_list shape, std::initializer_list group) -{ - //currently only supports read_write tensors - std::vector all_tensors; - all_tensors.insert(all_tensors.end(), read_write.begin(), read_write.end()); - all_tensors.insert(all_tensors.end(), read_only.begin(), read_only.end()); - std::vector all_vars; - all_vars.insert(all_vars.end(), var.begin(), var.end()); - all_vars.push_back(0); //group index offset - TFDispatchInfo info = {kernel_id, all_tensors.size(), all_tensors.data(), 0, nullptr, (uint)all_vars.size(), all_vars.data(), 0}; - - const TFTensor* read_write_tensors = read_write.begin(); - for (size_t i = 0; i < read_write.size(); i++) { - read_write_tensors[i].buffer->up_to_date = false; - } - - const size_t* shape_arr = shape.begin(); - const size_t* group_arr = group.begin(); - size_t dispatch_dim = shape.size(); - size_t group_dim = group.size(); - - size_t work_group_count = 1; - for (size_t i = 0; i < dispatch_dim - group_dim; i++) { - work_group_count *= shape_arr[i]; - } - - //only the last dimensions are divided by the group size - //TODO: consider reversing indices on backend too - for (size_t i = 0; i < group_dim; i++) { - size_t dim = shape_arr[dispatch_dim - i - 1]; - work_group_count *= (dim + group_arr[i] - 1) / group_arr[i]; - } - - info.work_group_count = work_group_count; - - runtime.dispatch(info, runtime.custom_data); -} - -void TFContext::region_begin(std::string name) -{ - if(runtime.region == nullptr) { - return; - } - runtime.region(name.c_str(), true, runtime.custom_data); -} - -void TFContext::region_end(std::string name) -{ - if(runtime.region == nullptr) { - return; - } - runtime.region(name.c_str(), false, runtime.custom_data); -} - -template -void TFContext::print_value(std::string name, T value) -{ - std::cout << name << ": " << value << std::endl; -} - -void TFContext::assert_value(std::string name, bool condition) -{ - if(!condition) { - throw std::runtime_error("Assertion failed: " + name); - } -} - -)"; - return implementation; -} - -void GenerateCode(Program* program) { - string final_source = GetCPPHeader(); - final_source += GetCPPImplementation(); - - GenerateNodeNames(*program->ir_); - int input_count = (int)program->ir_->input_memory_map.size(); - int output_count = (int)program->ir_->output_memory_map.size(); - - // Generate code for each compute kernel - map dispatch_code; - - for (auto& kernel : program->kernels_) { - global_kernel_manager->AddKernelID(program, &kernel); - kernel.kernel_name_ = "kernel_" + to_string(kernel.kernel_id_); - - // Generate kernel - vector read_write_nodes; - read_write_nodes.resize(kernel.read_write_memory.size()); - for (auto& read_write : kernel.read_write_memory) { - read_write_nodes[read_write.second] = read_write.first; - } - vector read_only_nodes; - read_only_nodes.resize(kernel.read_only_memory.size()); - for (auto& read_only : kernel.read_only_memory) { - read_only_nodes[read_only.second] = read_only.first; - } - - vector variable_nodes; - variable_nodes.resize(kernel.variables.size()); - for (auto& variable : kernel.variables) { - variable_nodes[variable.second] = variable.first; - } - - string read_write_args = "{"; - for (int d = 0; d < read_write_nodes.size(); d++) { - if (d != 0) { - read_write_args += ", "; - } - read_write_args += read_write_nodes[d]->var_name; - } - read_write_args += "}"; - string read_only_args = "{"; - for (int d = 0; d < read_only_nodes.size(); d++) { - if (d != 0) { - read_only_args += ", "; - } - read_only_args += read_only_nodes[d]->var_name; - } - read_only_args += "}"; - - string variable_args = "{"; - for (int d = 0; d < variable_nodes.size(); d++) { - if (d != 0) { - variable_args += ", "; - } - variable_args += "asuint(" + ReadVariable(variable_nodes[d]) + ")"; - } - variable_args += "}"; - - string shape_args = "{"; - int dims = (int)kernel.shape.size(); - for (int d = 0; d < kernel.shape.size(); d++) { - if (d != 0) { - shape_args += ", "; - } - shape_args += "(uint)" + ReadVariable(kernel.shape[ArgID(ArgType::Shape, dims - d - 1)]); - } - shape_args += "}"; - - string group_args = "{"; - for (int d = 0; d < kernel.root->group_size.size(); d++) { - if (d != 0) { - group_args += ", "; - } - group_args += to_string(kernel.root->group_size[d]); - } - group_args += "}"; - - GenerateKernel(program, &kernel); - - if (current_backend == BackendType::CPU) { - final_source += kernel.full_generated_code_; - } - - dispatch_code[kernel.root] = "tf.dispatch(" + to_string(kernel.kernel_id_) + ", " + read_write_args + ", " + read_only_args + ", " + variable_args + ", " + shape_args + ", " + group_args + ")"; - } - - GenerateMain(program, dispatch_code, input_count, output_count); - - final_source += program->main_function_; - - string host_code = - "\n" - "extern \"C\" " -#ifdef _WIN32 - "__declspec(dllexport) " -#endif - "int " - "main" - "(TFTensor* in, TFTensor* out, TFRuntime runtime)\n" - "{\n" - " auto outputs = " + program->program_name + "(TFContext(runtime)"; - - if (input_count > 0) { - host_code += ", "; - } - - for (int i = 0; i < input_count; i++) { - host_code += "in[" + to_string(i) + "]"; - if (i != input_count - 1) { - host_code += ", "; - } - } - host_code += ");\n"; - - for (int i = 0; i < output_count; i++) { - host_code += " out[" + to_string(i) + "] = std::get<" + to_string(i) + ">(outputs);\n"; - } - - host_code += " return 0;\n}\n"; - - final_source += host_code; - - program->generated_code_ = final_source; -} -void GenerateMain(Program* program, map& dispatch_code, - int input_count, int output_count) { - CPPGenerator generator = CPPGenerator(program->ir_); - generator.custom_generated_code_ = dispatch_code; - generator.GenerateCode(program->ir_->root); - - string main_code = "\nstd::tuple<"; - for (int i = 0; i < output_count; i++) { - main_code += "TFTensor"; - if (i != output_count - 1) { - main_code += ", "; - } - } - main_code += "> " + program->program_name + "(TFContext tf"; - if (input_count > 0) { - main_code += ", "; - } - for (int i = 0; i < input_count; i++) { - Node* input_node = program->ir_->input_memory_map[i]; - main_code += "TFTensor " + input_node->var_name; - if (i != input_count - 1) { - main_code += ", "; - } - } - main_code += ")\n{\n"; - - main_code += AddIndent(generator.AssembleString(), " "); - - main_code += " return {"; - - for (int i = 0; i < output_count; i++) { - Node* output_node = program->ir_->output_memory_map[i]; - main_code += output_node->var_name; - if (i != output_count - 1) { - main_code += ", "; - } - } - main_code += "};\n}\n"; - - program->main_function_ = main_code; -} - -void GenerateCPPKernel(Program* program, Kernel* kernel) { - CPPGenerator generator = CPPGenerator(program->ir_); - generator.GenerateKernelCode(kernel); - string kernel_code = generator.AssembleString(); - - string loop = ""; - loop += GetBufferDeclarations(kernel, [](const string& name, const string& type_name, size_t binding) { - return " uint* " + name + "_mem = mem[" + to_string(binding) + "];\n"; - }); - - kernel->var_names = vector(kernel->variables.size()); - kernel->var_types = vector(kernel->variables.size()); - for (auto var : kernel->variables) { - kernel->var_names[var.second] = var.first->var_name; - kernel->var_types[var.second] = type_names[var.first->format.type]; - } - kernel->var_names.push_back("_kernel_block_offset"); - kernel->var_types.push_back(type_names[TFType::Uint]); - for (int i = 0; i < kernel->var_names.size(); i++) { - loop += " " + kernel->var_types[i] + " var_" + kernel->var_names[i] + " = as" + kernel->var_types[i] + "(var[" + to_string(i) + "]);\n"; - } - - loop += " #pragma omp parallel for\n"; - loop += " for (int block_id = var__kernel_block_offset; block_id < (work_group_count+var__kernel_block_offset); block_id++)\n"; - loop += " {\n"; - - loop += GetGroupBufferDeclarations(kernel, [](const string& name, const string& type_name, size_t size) { - return " " + type_name + " " + name + "[" + to_string(size) + "];\n"; - }); - - for (int d = 0; d < kernel->root->group_size.size(); d++) { - int dim = (int)kernel->root->group_size.size() - d - 1; - loop += " for (int block_thread_id" + to_string(dim) + - " = 0; block_thread_id" + to_string(dim) + " < " + - to_string(kernel->root->group_size[dim]) + "; block_thread_id" + - to_string(dim) + "++)\n"; - } - loop += " {\n"; - string loop_end = " }\n"; - loop_end += " }\n"; - loop += AddIndent(kernel_code, " "); - loop += loop_end; - - string kernel_source = - "\n" - "extern \"C\" " - #ifdef _WIN32 - "__declspec(dllexport) " - #endif - "void " + - kernel->kernel_name_ + - "(uint* var, uint** mem, uint work_group_count)\n" - "{\n" + loop + - "}\n"; - - kernel->full_generated_code_ = kernel_source; - kernel->generated_header_ = ""; - kernel->generated_bindings_ = ""; - kernel->generated_main_ = kernel_source; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Langs/GLSL.cpp b/TensorFrost/Backend/CodeGen/Langs/GLSL.cpp deleted file mode 100644 index a7a83558..00000000 --- a/TensorFrost/Backend/CodeGen/Langs/GLSL.cpp +++ /dev/null @@ -1,188 +0,0 @@ -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { -using namespace std; - -class GLSLGenerator : public CodeGenerator { - public: - GLSLGenerator(IR* ir) : CodeGenerator(ir) { - name_map_ = { - {"var", "var."}, - {"modf", "mod"}, - {"atan2", "atan"}, - {"lerp", "mix"}, - {"reversebits", "bitfieldReverse"}, - {"frac", "fract"}, - {"group_barrier", "barrier"} - }; - } - - string TypeCast(string type_name, string input) override { - return type_name + "(" + input + ")"; - } - - string GenerateAtomicOp(const string& op, const string& input_type_name, - const string& output_type_name, - const string& address, const string& input, const string& output, const string& memory_name) override { - if (op == "InterlockedAdd") { - if(input_type_name == "float") - { - return "atomicAdd_"+memory_name+"(" + address + ", " + input + ")"; - } - return "atomicAdd("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedAdd_Prev") { - if(input_type_name == "float") - { - return output_type_name + "(atomicAdd_"+memory_name+"(" + address + ", " + input +"))"; - } - return output_type_name + "(atomicAdd("+memory_name+"[" + address + "], uint(" + input + ")))"; - } else if (op == "InterlockedMin") { - return "atomicMin("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedMax") { - return "atomicMax("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedAnd") { - return "atomicAnd("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedOr") { - return "atomicOr("+memory_name+"[" + address + "], uint(" + input + "))"; - } else if (op == "InterlockedXor") { - return "atomicXor("+memory_name+"[" + address + "], uint(" + input + "))"; - } - else - { - throw runtime_error("Unsupported atomic operation: " + op); - } - } - -}; - -string GetGLSLHeader(Kernel* kernel) { - string header = R"( -#version 430 - -uint pcg(uint v) { - uint state = v * 747796405u + 2891336453u; - uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -float pcgf(uint v) { - return float(pcg(v)) / float(0xffffffffu); -} - -float asfloat(uint x) { - return uintBitsToFloat(x); -} - -uint asuint(float x) { - return floatBitsToUint(x); -} - -uint asuint(bool x) { - return uint(x); -} - -uint asuint(int x) { - return uint(x); -} - -uint asuint(uint x) { - return x; -} - -int asint(uint x) { - return int(x); -} - -bool asbool(uint x) { - return bool(x); -} - -)"; - kernel->var_names = vector(kernel->variables.size()); - kernel->var_types = vector(kernel->variables.size()); - header += "\nstruct UBO {\n"; - for (auto var : kernel->variables) { - kernel->var_names[var.second] = var.first->var_name; - kernel->var_types[var.second] = type_names[var.first->format.type]; - } - kernel->var_names.push_back("_kernel_block_offset"); - kernel->var_types.push_back(type_names[TFType::Uint]); - for (int i = 0; i < kernel->var_names.size(); i++) { - header += " " + kernel->var_types[i] + " " + kernel->var_names[i] + ";\n"; - } - header += "};\n\n"; - return header; -} - -string GLSLBufferDeclaration(const string& name, const string& type_name, const size_t binding) { - string decl = "layout(std430, binding = " + to_string(binding) + ") buffer buf_" + name + " {\n " + type_name + " " + name + "_mem[];\n};\n"; - //add atomic functions - decl += R"( -float atomicAdd_)" + name + R"(_mem(int index, float val) { - uint uval = floatBitsToUint(val); - uint tmp0 = 0; - uint tmp1 = 0; - - while (true) { - tmp0 = atomicCompSwap()" + name + R"(_mem[index], tmp1, uval); - if (tmp1 == tmp0) break; - tmp1 = tmp0; - uval = floatBitsToUint(val + uintBitsToFloat(tmp1)); - } - - return uintBitsToFloat(tmp1); -} - -)"; - - return decl; -} - -string GLSLGroupBufferDeclaration(const string& name, const string& type_name, const size_t size) { - string decl = "shared " + type_name + " " + name + "[" + to_string(size) + "];\n"; - return decl; -} - -void GenerateGLSLKernel(Program* program, Kernel* kernel) { - kernel->generated_header_ = GetGLSLHeader(kernel); - - string buffers = GetBufferDeclarations(kernel, GLSLBufferDeclaration) + "\n"; - buffers += "layout(std140) uniform UBOBlock {\n UBO var;\n};\n\n"; - kernel->generated_bindings_ = buffers; - - string main_code = ""; - - main_code += GetGroupBufferDeclarations(kernel, GLSLGroupBufferDeclaration) + "\n"; - - vector group_size = kernel->root->group_size; - //pad with 1s - while (group_size.size() < 3) { - group_size.push_back(1); - } - - main_code += "layout (local_size_x = " + to_string(group_size[0]) + ", local_size_y = " + to_string(group_size[1]) + ", local_size_z = " + to_string(group_size[2]) + ") in;\n"; - - - main_code += R"( -void main() { - int block_id = int(gl_WorkGroupID.x + var._kernel_block_offset); - int block_thread_id0 = int(gl_LocalInvocationID.x); - int block_thread_id1 = int(gl_LocalInvocationID.y); - int block_thread_id2 = int(gl_LocalInvocationID.z); - -)"; - - GLSLGenerator generator = GLSLGenerator(program->ir_); - generator.GenerateKernelCode(kernel); - string kernel_code = generator.AssembleString(); - - main_code += AddIndent(kernel_code, " "); - - main_code += "}\n"; - - kernel->generated_main_ = main_code; - - kernel->full_generated_code_ = kernel->generated_header_ + kernel->generated_bindings_ + kernel->generated_main_; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Langs/HLSL.cpp b/TensorFrost/Backend/CodeGen/Langs/HLSL.cpp deleted file mode 100644 index b861b1ca..00000000 --- a/TensorFrost/Backend/CodeGen/Langs/HLSL.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { -using namespace std; - -class HLSLGenerator : public CodeGenerator { - public: - HLSLGenerator(IR* ir) : CodeGenerator(ir) { - name_map_ = { - {"var", "var."}, - {"group_barrier", "GroupMemoryBarrierWithGroupSync"} - }; - } - - string TypeCast(string type_name, string input) override { - return type_name + "(" + input + ")"; - } - - string GenerateAtomicOp(const string& op, const string& input_type_name, - const string& output_type_name, const string& address, - const string& input, const string& output, const string& memory_name) override - { - if (op == "InterlockedAdd") { - if(input_type_name == "float") - { - return "InterlockedAddF("+memory_name+", " + address + ", " + input + ")"; - } - return "InterlockedAdd("+memory_name+"[" + address + "], " + input + ")"; - } else if (op == "InterlockedAdd_Prev") { - if(input_type_name == "float") - { - return "InterlockedAddF("+memory_name+", " + address + ", " + input + ")"; - } - additional_lines.push_back("InterlockedAdd("+memory_name+"[" + address + "], " + - input + ", " + output + ");"); - return "0"; - } else { - return op + "("+memory_name+"[" + address + "], " + input + ")"; - } - } -}; - -string GetHLSLHeader(Kernel* kernel) { - string header =R"( -uint pcg(uint v) -{ - uint state = v * 747796405u + 2891336453u; - uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -float pcgf(uint v) -{ - return float(pcg(v)) / float(0xffffffffu); -} - -float InterlockedAddF(RWStructuredBuffer buffer, int index, float val) -{ - uint uval = asuint(val), tmp0 = 0, tmp1 = 0; - [allow_uav_condition] while (true) { - InterlockedCompareExchange(buffer[index], tmp0, uval, tmp1); - if (tmp1 == tmp0) break; - tmp0 = tmp1; - uval = asuint(val + asfloat(tmp1)); - } - return asfloat(tmp1); -} - -)"; - kernel->var_names = vector(kernel->variables.size()); - kernel->var_types = vector(kernel->variables.size()); - header += "\nstruct UBO {\n"; - for (auto var : kernel->variables) { - kernel->var_names[var.second] = var.first->var_name; - kernel->var_types[var.second] = type_names[var.first->format.type]; - } - kernel->var_names.push_back("_kernel_block_offset"); - kernel->var_types.push_back(type_names[TFType::Uint]); - for (int i = 0; i < kernel->var_names.size(); i++) { - header += " " + kernel->var_types[i] + " " + kernel->var_names[i] + ";\n"; - } - header += "};\n\n"; - return header; -} - -string HLSLBufferDeclaration(const string& name, const string& type_name, const size_t binding) { - return "RWStructuredBuffer<" + type_name + "> " + name + "_mem : register(u" + to_string(binding) + ");\n"; -} - -string HLSLGroupBufferDeclaration(const string& name, const string& type_name, const size_t size) { - string decl = "groupshared " + type_name + " " + name + "[" + to_string(size) + "];\n"; - return decl; -} - -void GenerateHLSLKernel(Program* program, Kernel* kernel) { - kernel->generated_header_ = GetHLSLHeader(kernel); - - kernel->generated_bindings_ = GetBufferDeclarations(kernel, HLSLBufferDeclaration) + "\n"; - kernel->generated_bindings_ += "cbuffer ubo : register(b0) { UBO var; }\n"; - - vector group_size = kernel->root->group_size; - // pad with 1s - while (group_size.size() < 3) { - group_size.push_back(1); - } - - string main_function = ""; - - main_function += GetGroupBufferDeclarations(kernel, HLSLGroupBufferDeclaration) + "\n"; - - main_function += "[numthreads(" + to_string(group_size[0]) + ", " + to_string(group_size[1]) + ", " + to_string(group_size[2]) + ")]"; - - main_function += R"( -void main(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID) -{ - int block_id = gid.x + var._kernel_block_offset; - int block_thread_id0 = gtid.x; - int block_thread_id1 = gtid.y; - int block_thread_id2 = gtid.z; - -)"; - - HLSLGenerator generator = HLSLGenerator(program->ir_); - generator.GenerateKernelCode(kernel); - string kernel_code = generator.AssembleString(); - - main_function += AddIndent(kernel_code, " "); - - main_function += "}\n"; - kernel->generated_main_ = main_function; - - kernel->full_generated_code_ = kernel->generated_header_ + kernel->generated_bindings_ + kernel->generated_main_; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/CodeGen/Langs/Listing.cpp b/TensorFrost/Backend/CodeGen/Langs/Listing.cpp deleted file mode 100644 index 0704ed04..00000000 --- a/TensorFrost/Backend/CodeGen/Langs/Listing.cpp +++ /dev/null @@ -1,155 +0,0 @@ -#include "Backend/CodeGen/Generators.h" -#include "Compiler/KernelGen.h" - -namespace TensorFrost { -using namespace std; - -string GetNodeString(const Node* node, bool verbose) { - string listing = ""; - if (node->format.type != TFType::None) { - listing += DataTypeToString(node->format.type) + "(" + to_string(node->format.size) + ") "; - } - - if (node->format.type != TFType::None) { - // the tensor name - listing += node->var_name + "(" + to_string(node->format.size) + ") = "; - } - - listing += node->name + "("; - - const ArgumentManager& args = node->args; - - auto ArgTypePrint = [&](string name, ArgType type) { - if (args.Has(type)) { - string arr = name + "=["; - for (int i = 0; i < args.Count(type); i++) { - if (i != 0) arr += ","; - arr += GetNodeName(args.Get(type, i), false); - } - arr += "], "; - return arr; - } - return string(); - }; - - listing += ArgTypePrint("memory", ArgType::Memory); - listing += ArgTypePrint("inputs", ArgType::Input); - listing += ArgTypePrint("indices", ArgType::Index); - - if(node->args.OutputCount() > 0) { - listing += "outputs=["; - for(auto [in, out]: node->args.Outputs()) { - listing += GetNodeName(out, false) + ", "; - } - listing += "], "; - } - - if (!node->data.empty()) { - listing += "data=["; - for (int i = 0; i < node->data.size(); i++) { - if (i != 0) listing += ","; - listing += to_string(node->data[i]); - } - listing += "], "; - } - - if(node->flags.count() > 0) { - listing += "flags={"; - auto flags = node->flags.get_data(); - for(auto flad_data : flags) { - NodeProp flag = flad_data.first; - int64_t data = flad_data.second; - listing += NodeFlagsToString(flag); - if(data >= 0) { - listing += "(" + to_string(data) + ")"; - } - listing += ", "; - } - listing += "}, "; - } - - if (node->cost_ >= 0) { - listing += "cost=" + to_string(node->cost_) + ", "; - } - - if(node->memory_deps.size() > 0) { - listing += "memory_deps_count=" + to_string(node->memory_deps.size()) + ", "; - } - - if(node->indexing_mode_ != IndexingMode::Clamp) { - listing += "indexing_mode=" + IndexingModeToString(node->indexing_mode_) + ", "; - } - - listing += ArgTypePrint("shape", ArgType::Shape); - -#ifdef _DEBUG - if (verbose) { - listing += "index=" + to_string(node->index_) + ", "; - listing += "debug_index=" + to_string(node->debug_index) + ", "; - listing += "debug_name=" + node->debug_name + ", "; - listing += "created_in=" + node->created_in_pass + ", "; - listing += "created_in_func=" + node->created_in_function + ", "; - } -#endif - - listing += ")"; - - return listing; -} - -string GetOperationListing(const IR& ir, bool compact, map debug) { - // first give unique names to all the tensors - GenerateNodeNames(ir); - //ClusterProp clusters = ir.GetClusterProperties(); - - // now create the listing - int prev_depth = 0; - string listing; - for (auto node = ir.begin(); !node.end(); node.next()) { - if (compact) { - if (node->name == "const") continue; - } - - if (debug.contains(node.get())) { - listing += "[DEBUG] " + debug[node.get()] + ": \n"; - } - - // indent - int depth = node.depth() - 1; - //add scope brackets - if (depth < prev_depth) { - for (int i = prev_depth - 1; i >= depth; i--) { - for (int j = 0; j < i; j++) { - listing += " "; - } - listing += "}\n"; - } - } - else if (depth > prev_depth) { - for (int i = prev_depth; i < depth; i++) { - for (int j = 0; j < i; j++) { - listing += " "; - } - listing += "{\n"; - } - } - for (int i = 0; i < depth; i++) { - listing += " "; - } - prev_depth = depth; - - listing += GetNodeString(*node, true); - listing += "\n"; - } - - for (int i = prev_depth - 1; i >= 0; i--) { - for (int j = 0; j < i; j++) { - listing += " "; - } - listing += "}\n"; - } - - return listing; -} - -} // namespace TensorFrost diff --git a/TensorFrost/Backend/KernelManager.cpp b/TensorFrost/Backend/KernelManager.cpp deleted file mode 100644 index bf06a59b..00000000 --- a/TensorFrost/Backend/KernelManager.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include "KernelManager.h" - -namespace TensorFrost { -void KernelManager::AddKernelID(Program *program, Kernel *kernel) { - programs.insert(program); - kernel->kernel_id_ = global_kernel_id++; - kernel_map[kernel->kernel_id_] = kernel; -} - -vector KernelManager::GetAllMainFunctions() { - vector main_functions; - for (auto& program : programs) { - main_functions.push_back(program->main_function_); - } - return main_functions; -} - -vector, vector>>> KernelManager::GetAllKernels() { - vector, vector>>> kernels; - kernels.resize(kernel_map.size()); - for (auto& kernel : kernel_map) { - vector> args; - map memory_bindings = kernel.second->GetMemoryBindings(); - args.resize(memory_bindings.size()); - for (auto& [mem_node, binding] : memory_bindings) { - string name = mem_node->var_name + "_mem"; - string type_name = "uint"; - args[binding] = {name, type_name}; - } - tuple code = {kernel.second->generated_header_, kernel.second->generated_bindings_, kernel.second->generated_main_}; - kernels[kernel.first] = {code, args}; - } - return kernels; -} - -KernelManager* global_kernel_manager = nullptr; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/KernelManager.h b/TensorFrost/Backend/KernelManager.h deleted file mode 100644 index 172788c8..00000000 --- a/TensorFrost/Backend/KernelManager.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "Compiler/KernelGen.h" -#include "TensorMemory.h" - -namespace TensorFrost { - -class KernelManager -{ - unordered_set programs; - unordered_map kernel_map; - size_t global_kernel_id = 0; - public: - - KernelManager() = default; - virtual void DispatchKernel(TFDispatchInfo info) = 0; - void AddKernelID(Program* program, Kernel* kernel); - vector GetAllMainFunctions(); - - vector, vector>>> GetAllKernels(); - Kernel* GetKernel(size_t kernel_id) { return kernel_map[kernel_id]; } -}; - -extern KernelManager* global_kernel_manager; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/RenderDoc.cpp b/TensorFrost/Backend/RenderDoc.cpp deleted file mode 100644 index cd26b747..00000000 --- a/TensorFrost/Backend/RenderDoc.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "RenderDoc.h" -#include -#include -#if defined(_WIN32) -#include -#else -#include -#include -#include -#endif - -namespace TensorFrost { - RENDERDOC_API_1_4_2* RDCAPI = nullptr; - - void LoadRDCAPI() - { - if (RDCAPI) - return; - -#if defined(_WIN32) - if (HMODULE mod = GetModuleHandleA("renderdoc.dll")) - { - std::cout << "renderdoc.dll successfully loaded" << std::endl; - pRENDERDOC_GetAPI RENDERDOC_GetAPI = (pRENDERDOC_GetAPI)GetProcAddress(mod, "RENDERDOC_GetAPI"); - RENDERDOC_GetAPI(eRENDERDOC_API_Version_1_4_2, (void**)&RDCAPI); - } -#else - if (void* mod = dlopen("librenderdoc.so", RTLD_NOW | RTLD_NOLOAD)) - { - std::cout << "librenderdoc.so successfully loaded" << std::endl; - pRENDERDOC_GetAPI RENDERDOC_GetAPI = (pRENDERDOC_GetAPI)dlsym(mod, "RENDERDOC_GetAPI"); - RENDERDOC_GetAPI(eRENDERDOC_API_Version_1_4_2, (void**)&RDCAPI); - } -#endif - } - - void StartRenderDocCapture() - { - LoadRDCAPI(); - - if (RDCAPI) - { - std::cout << "RenderDoc capture started" << std::endl; - RDCAPI->StartFrameCapture(NULL, NULL); - } - } - - void EndRenderDocCapture() - { - if (RDCAPI) - { - std::cout << "RenderDoc capture ended" << std::endl; - RDCAPI->EndFrameCapture(NULL, NULL); - } - } -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/RenderDoc.h b/TensorFrost/Backend/RenderDoc.h deleted file mode 100644 index 9a9f2f2d..00000000 --- a/TensorFrost/Backend/RenderDoc.h +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once - -namespace TensorFrost { - -void StartRenderDocCapture(); -void EndRenderDocCapture(); - -} // namespace TensorFrost diff --git a/TensorFrost/Backend/TensorMemory.cpp b/TensorFrost/Backend/TensorMemory.cpp deleted file mode 100644 index 9f439610..00000000 --- a/TensorFrost/Backend/TensorMemory.cpp +++ /dev/null @@ -1,236 +0,0 @@ -#include "TensorMemory.h" - -namespace TensorFrost { - -size_t GetLinearSize(const vector& shape) { - size_t size = 1; - for (size_t dim : shape) { - size *= dim; - } - return size; -} - -vector GetShape(const TFTensor *tensor) { - vector shape; - for (size_t i = 0; i < tensor->dim; i++) { - shape.push_back(tensor->shape[i]); - } - return shape; -} - -size_t GetSize(const TFTensor *tensor) { - size_t size = 1; - for (size_t i = 0; i < tensor->dim; i++) { - size *= tensor->shape[i]; - } - return size; -} - -TFBuffer * TensorMemoryManager::AllocateBuffer(size_t size) { - TFBuffer* buffer = CreateBuffer(size); - if(allocation_history.contains(size)) { - size_t old_delay = GetDeallocationDelay(size); - allocation_delay[size] = std::min(std::max(old_delay, tick - allocation_history[size]), MAX_POSSIBLE_UNUSED_TIME); - } else { - allocation_delay[size] = DEFAULT_MAX_UNUSED_TIME; - } - buffers_created++; - //add the buffer to the list of allocated buffers - allocated_buffers[size].insert(buffer); - allocation_history[size] = tick; - return buffer; -} - -TFTensor * TensorMemoryManager::AllocateTensor(const vector &shape, const TFDataFormat type, const char* name) { - size_t size = GetLinearSize(shape); - - if (size == 0) { - throw invalid_argument("Trying to allocate a tensor with size 0"); - } - - TFBuffer* buf = TryAllocateBuffer(size); - buf->read_only = false; - ((TFBufferTemplate*)buf)->UpdateName(name); - return MakeTensor(shape, buf, type); -} - -TFTensor * TensorMemoryManager::AllocateTensorWithData(const vector &shape, const vector &data, - const TFDataFormat type, bool read_only, const char* name) { - TFTensor* tensor_memory = AllocateTensor(shape, type, name); - tensor_memory->buffer->read_only = read_only; - ((TFBufferTemplate*)tensor_memory->buffer)->SetDataAtOffset(0, data); - return tensor_memory; -} - -void TensorMemoryManager::DeallocateTensor(TFTensor tensor) { - DeallocateBuffer(tensor.buffer); -} - -TFTensor * TensorMemoryManager::MakeTensor(size_t *shape, size_t dim, TFBuffer *buf, TFDataFormat type) { - TFTensor* tensor = new TFTensor(); - tensor->buffer = buf; - tensor->dim = dim; - tensor->shape = shape; - tensor->format = type; - return tensor; -} - -TFTensor * TensorMemoryManager::MakeTensor(const vector &shape, TFBuffer *buf, TFDataFormat type) { - size_t* shape_arr = new size_t[shape.size()]; - std::copy(shape.begin(), shape.end(), shape_arr); - return MakeTensor(shape_arr, shape.size(), buf, type); -} - -size_t TensorMemoryManager::GetAllocatedSize() const { - size_t total = 0; - for(auto& [size, buffers]: allocated_buffers) { - total += size * buffers.size(); - } - return total; -} - -size_t TensorMemoryManager::GetUnusedAllocatedSize() const { - size_t total = 0; - for(auto& [size, buffers]: allocated_buffers) { - for(auto& buffer: buffers) { - if(unused_buffers.contains(buffer)) { - total += size; - } - } - } - return total; -} - -void TensorMemoryManager::DeallocateBuffer(TFBuffer *buffer) { - unused_buffers.insert(buffer); - buffer->time_since_used = 0; - buffer->used_size = 0; - buffer->up_to_date = false; - buffer->name = "none"; -} - -void TensorMemoryManager::RemoveBuffer(TFBuffer *buffer) { - size_t size = buffer->size; - allocated_buffers[size].erase(buffer); - unused_buffers.erase(buffer); - DeleteBuffer(buffer); - buffers_removed++; -} - -//#define READBACK_DEBUG - -vector TensorMemoryManager::Readback(const TFTensor *memory) { - vector data(GetSize(memory)); -#ifdef READBACK_DEBUG - cout << "Reading back " << data.size() << " elements from buffer of size " << memory->buffer->size << endl; -#endif - ((TFBufferTemplate*)memory->buffer)->GetDataAtOffset(0, data.size(), data.data()); - return data; -} - -uint TensorMemoryManager::ReadbackValue(const TFTensor *memory, size_t index) { - uint32_t data; - ((TFBufferTemplate*)memory->buffer)->GetDataAtOffset(index, 1, &data); - return data; -} - -void TensorMemoryManager::Writeback(const TFTensor *memory, const vector &data) { - ((TFBufferTemplate*)memory->buffer)->SetDataAtOffset(0, data); -} - -void TensorMemoryManager::WritebackValue(const TFTensor *memory, size_t index, uint32_t value) { - ((TFBufferTemplate*)memory->buffer)->SetDataAtOffset(index, {value}); -} - -void TensorMemoryManager::UpdateTick() { - unordered_set buffers_to_delete; - - for(auto& buffer: unused_buffers) { - size_t buf_size = buffer->size; - if(buffer->time_since_used > GetDeallocationDelay(buf_size)) { - buffers_to_delete.insert(buffer); - } else { - buffer->time_since_used++; - } - } - - //delete all buffers that are marked for deletion - for(auto& buffer: buffers_to_delete) { - RemoveBuffer(buffer); - } - tick++; - -#ifdef DEBUG_DYNAMIC_ALLOCATION - if(tick%2048 == 0) { - if(buffers_created> 0 || buffers_removed > 0) { - cout << "Note: " << buffers_created << " buffers created and " << buffers_removed << " buffers removed in the last 2048 ticks" << endl; - } - buffers_created = 0; - buffers_removed = 0; - } -#endif -} - -size_t TensorMemoryManager::GetDeallocationDelay(size_t buf_size) const { - if(allocation_delay.contains(buf_size)) { - return allocation_delay.at(buf_size); - } - - return DEFAULT_MAX_UNUSED_TIME; -} - -TFBuffer *TensorMemoryManager::TryAllocateBuffer(size_t size) { - //try to find a non-used buffer of the correct size - TFBuffer* buffer = nullptr; - bool found = false; - //find the smallest buffer that is larger than the requested size - size_t min_size = size; - size_t max_size = 16 * size; - //get iterator to the first buffer that is larger than the requested size - auto it = allocated_buffers.lower_bound(min_size); - //if no buffer is larger than the requested size, get the first buffer - if(it == allocated_buffers.end()) { - it = allocated_buffers.begin(); - } - //iterate through the buffers - for(; it != allocated_buffers.end(); it++) { - if(it->first > max_size) { - break; - } - if(it->first < size) { - continue; - } - for(auto buf: it->second) { - if(!unused_buffers.contains(buf)) { - continue; - } - buffer = buf; - found = true; - } - if(found) { - break; - } - } - //if no buffer was found, create a new one - if(!found) { - buffer = AllocateBuffer(size); - } else { - unused_buffers.erase(buffer); - } - buffer->used_size = size; - buffer->time_since_used = 0; - UpdateTick(); - return buffer; -} - -TensorMemoryManager::~TensorMemoryManager() { - for(auto& [size, buffers]: allocated_buffers) { - for(auto& buffer: buffers) { - DeleteBuffer(buffer); - } - } -} - -TensorMemoryManager* global_memory_manager = nullptr; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/TensorMemory.h b/TensorFrost/Backend/TensorMemory.h deleted file mode 100644 index 546d3b37..00000000 --- a/TensorFrost/Backend/TensorMemory.h +++ /dev/null @@ -1,146 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "../Tensor/Tensor.h" - -//#define DEBUG_DYNAMIC_ALLOCATION - -namespace TensorFrost { - -using namespace std; - -extern "C" { - struct TFBuffer { - size_t size = 0; - size_t used_size = 0; - size_t time_since_used = 0; - bool up_to_date = false; - bool read_only = false; - const char* name = nullptr; - //add type descriptor (for special kinds of buffers) - }; - - struct TFTensor { - TFBuffer* buffer; - TFDataFormat format; - size_t dim; - const size_t* shape; - }; - - struct TFTensorList { - size_t count; - const TFTensor* tensors; - }; - - struct TFDispatchInfo { - size_t kernel_id; - size_t read_write_count; - const TFTensor* read_write_tensors; - size_t read_only_count; - const TFTensor* read_only_tensors; - size_t variable_count; - const uint32_t* variables; - size_t work_group_count; - }; - - typedef TFTensor alloc_func(const char*, const size_t*, size_t, TFDataFormat, void*); - typedef void dealloc_func(TFTensor, void*); - typedef uint readback_func(TFTensor, size_t, void*); - typedef void writeback_func(TFTensor, size_t, uint32_t, void*); - typedef void dispatch_func(TFDispatchInfo, void*); - typedef void region_func(const char*, bool, void*); - - struct TFRuntime { - alloc_func* alloc; - dealloc_func* dealloc; - readback_func* readback; - writeback_func* writeback; - dispatch_func* dispatch; - region_func* region; - void* custom_data; - }; - - typedef void cpu_dispatch_func(const uint32_t* var, uint32_t** mem, uint work_group_count); - typedef void main_func(TFTensor*, TFTensor*, TFRuntime); -} - -class TFBufferTemplate : public -TFBuffer { -public: - TFBufferTemplate(size_t size) : TFBuffer{ size } {} - - virtual void UpdateName(const char* name) { - throw std::runtime_error("UpdateName not implemented"); - } - virtual void SetDataAtOffset(size_t offset, const vector& data) { - throw std::runtime_error("SetDataAtOffset not implemented"); - } - virtual void GetDataAtOffset(size_t offset, size_t size, uint32_t* data) { - throw std::runtime_error("GetDataAtOffset not implemented"); - } -}; - -using uint = unsigned int; - -size_t GetLinearSize(const vector& shape); -vector GetShape(const TFTensor* tensor); -size_t GetSize(const TFTensor* tensor); - -class TensorMemoryManager { -private: - size_t tick = 0; - size_t buffers_created = 0; - size_t buffers_removed = 0; - const size_t DEFAULT_MAX_UNUSED_TIME = 128; - const size_t MAX_POSSIBLE_UNUSED_TIME = 32768; - map> allocated_buffers; - map allocation_history; //stores the last tick when a buffer of a certain size was allocated - map allocation_delay; //stores the time between the last 2 allocations of a buffer size - unordered_set unused_buffers; - - static TFTensor* MakeTensor(size_t* shape, size_t dim, TFBuffer* buf, TFDataFormat type); - static TFTensor* MakeTensor(const vector& shape, TFBuffer* buf, TFDataFormat type); - void UpdateTick(); - size_t GetDeallocationDelay(size_t size) const; - - TFBuffer* AllocateBuffer(size_t size); - TFBuffer* TryAllocateBuffer(size_t size); - void DeallocateBuffer(TFBuffer* buffer); - void RemoveBuffer(TFBuffer* buffer); - -protected: - virtual TFBuffer* CreateBuffer(size_t size) { - throw std::runtime_error("CreateBuffer not implemented"); - } - - virtual void DeleteBuffer(TFBuffer * buffer) { - throw std::runtime_error("DeleteBuffer not implemented"); - } - -public: - virtual vector Readback(const TFTensor* memory); - virtual uint ReadbackValue(const TFTensor* memory, size_t index); - virtual void Writeback(const TFTensor* memory, const vector& data); - virtual void WritebackValue(const TFTensor* memory, size_t index, uint32_t value); - - TFTensor* AllocateTensor(const vector& shape, const TFDataFormat type = TFTypeFloat32, const char* name = nullptr); - TFTensor* AllocateTensorWithData(const vector& shape, const vector& data, const TFDataFormat type = TFTypeFloat32, bool read_only = false, const char* name = nullptr); - void DeallocateTensor(TFTensor tensor); - - size_t GetAllocatedSize() const; - size_t GetUnusedAllocatedSize() const; - - ~TensorMemoryManager(); -}; - - -extern TensorMemoryManager* global_memory_manager; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Backend/include/Backend/RenderDoc.h b/TensorFrost/Backend/include/Backend/RenderDoc.h new file mode 100644 index 00000000..a89e14df --- /dev/null +++ b/TensorFrost/Backend/include/Backend/RenderDoc.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace TensorFrost { + +void StartRenderDocCapture(); +std::string EndRenderDocCapture(bool launchReplayUI = false); +bool IsRenderDocAvailable(); + +} // namespace TensorFrost diff --git a/TensorFrost/Backend/include/Backend/Vulkan.h b/TensorFrost/Backend/include/Backend/Vulkan.h new file mode 100644 index 00000000..d61041d2 --- /dev/null +++ b/TensorFrost/Backend/include/Backend/Vulkan.h @@ -0,0 +1,104 @@ +#pragma once +#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 +#include +#include +#include +#include +#include + +struct Buffer { + vk::Buffer buffer; + vk::DeviceMemory memory; + size_t size; +}; + +struct ComputeProgram { + vk::ShaderModule shaderModule; + vk::DescriptorSetLayout descriptorLayout; + vk::PipelineLayout pipelineLayout; + vk::Pipeline pipeline; + uint32_t numRO = 0, numRW = 0; + uint32_t pushConstantSize = 0; +}; + +struct ComputeBindings { + vk::DescriptorSet set{}; +}; + +// --- cache key + hash (as before) +struct DSKey { + vk::DescriptorSetLayout layout{}; + std::vector bufs; + std::vector sizes; + bool operator==(const DSKey& o) const { + return layout==o.layout && bufs==o.bufs && sizes==o.sizes; + } +}; +struct DSKeyHash { + size_t operator()(DSKey const& k) const noexcept { + auto mix = [](size_t& s, uint64_t v){ s ^= std::hash{}(v) + 0x9e3779b97f4a7c15ULL + (s<<6) + (s>>2); }; + size_t seed = 0; + mix(seed, (uint64_t)VkDescriptorSetLayout(k.layout)); + for (auto b : k.bufs) mix(seed, (uint64_t)b); + for (auto sz: k.sizes) mix(seed, (uint64_t)sz); + return seed; + } +}; + +// --- cached entry +struct CachedDS { + vk::DescriptorSet set{}; + std::vector buffers; // for invalidation + uint64_t lastUseTick = 0; + uint64_t useCount = 0; +}; + + +struct CachedBuf { + Buffer buf; + uint64_t lastUse=0; + uint64_t useCount=0; +}; + +// Holds instance, physical device, logical device, queue and command pool. +struct VulkanContext { + vk::Instance instance; + vk::PhysicalDevice physicalDevice; + uint32_t queueFamilyIndex = UINT32_MAX; // compute + uint32_t graphicsFamilyIndex = UINT32_MAX; // present candidate + vk::Device device; + vk::Queue computeQueue; + vk::Queue presentQueue; + vk::CommandPool commandPool; + + vk::DescriptorPool descriptorPool; + std::unordered_map dsCache; + size_t dsCacheCapacity = 256; // must be ≤ pool maxSets + uint64_t dsUseTick = 1; + + std::unordered_map> bufferCache; // key: exact byte size + size_t bufferCacheCapacity = 128; // total entries across all sizes + uint64_t bufferUseTick = 1; + + VulkanContext(); + ~VulkanContext(); +}; + +Buffer createBuffer(size_t count, size_t dtypeSize, bool readOnly); +void destroyBuffer(Buffer& buf); +void setBufferData(Buffer& buf, const void* src, size_t bytes, size_t offset = 0); +void getBufferData(const Buffer& buf, void* dst, size_t bytes, size_t offset = 0); + +ComputeProgram createComputeProgramFromSlang(const std::string& moduleName, + const std::string& source, const std::string& entry, + uint32_t roCount, uint32_t rwCount, uint32_t pushConstantSize = 0); +void destroyComputeProgram(ComputeProgram& prog); + +void runProgram(const ComputeProgram& prog, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers, + uint32_t groupCount, + const void* pushConstants, + size_t pushConstantSize); + +VulkanContext& getVulkanContext(); diff --git a/TensorFrost/Backend/include/Backend/Window.h b/TensorFrost/Backend/include/Backend/Window.h new file mode 100644 index 00000000..4414cee7 --- /dev/null +++ b/TensorFrost/Backend/include/Backend/Window.h @@ -0,0 +1,145 @@ +#pragma once +#include "Vulkan.h" +#include +#include +#include +#include + + +struct WindowContext; +void ReleaseImGui(WindowContext& ctx); +struct ImGuiContext; + +namespace TFWindowDetail { +void RegisterScrollContext(GLFWwindow* wnd, WindowContext* ctx); +void UnregisterScrollContext(GLFWwindow* wnd); +} + +struct WindowContext { + GLFWwindow* wnd{}; + vk::Instance instance; + vk::PhysicalDevice phys; + vk::Device device; + uint32_t presentFam{}; + vk::Queue queue; + + vk::SurfaceKHR surface; + vk::SwapchainKHR swapchain; + std::vector images; + vk::Format format{}; + vk::Extent2D extent{}; + vk::CommandPool pool; + vk::CommandBuffer cmd; + vk::Semaphore semImage, semDone; + vk::Fence fence; + + // ImGui integration helpers + vk::DescriptorPool imguiPool{}; + vk::RenderPass renderPass{}; + std::vector imageViews; + std::vector framebuffers; + ImGuiContext* imguiContext{}; + bool imguiFrameActive = false; + double scrollDeltaX = 0.0; + double scrollDeltaY = 0.0; + GLFWscrollfun prevScrollCallback = nullptr; + + WindowContext() = default; + WindowContext(const WindowContext&) = delete; + WindowContext& operator=(const WindowContext&) = delete; + + WindowContext(WindowContext&& o) noexcept { moveFrom(std::move(o)); } + WindowContext& operator=(WindowContext&& o) noexcept { + if (this != &o) { cleanup(); moveFrom(std::move(o)); } + return *this; + } + + ~WindowContext() { cleanup(); } + +private: + void moveFrom(WindowContext&& o) { + wnd=o.wnd; o.wnd=nullptr; + instance=o.instance; o.instance=nullptr; + phys=o.phys; o.phys=nullptr; + device=o.device; o.device=nullptr; + presentFam=o.presentFam; o.presentFam=0; + queue=o.queue; o.queue=nullptr; + surface=o.surface; o.surface=nullptr; + swapchain=o.swapchain; o.swapchain=nullptr; + images=std::move(o.images); + format=o.format; o.format=vk::Format{}; + extent=o.extent; o.extent=vk::Extent2D{}; + pool=o.pool; o.pool=nullptr; + cmd=o.cmd; o.cmd=nullptr; + semImage=o.semImage; o.semImage=nullptr; + semDone=o.semDone; o.semDone=nullptr; + fence=o.fence; o.fence=nullptr; + imguiPool=o.imguiPool; o.imguiPool=nullptr; + renderPass=o.renderPass; o.renderPass=nullptr; + imageViews=std::move(o.imageViews); + framebuffers=std::move(o.framebuffers); + imguiContext=o.imguiContext; o.imguiContext=nullptr; + imguiFrameActive=o.imguiFrameActive; o.imguiFrameActive=false; + scrollDeltaX = o.scrollDeltaX; o.scrollDeltaX = 0.0; + scrollDeltaY = o.scrollDeltaY; o.scrollDeltaY = 0.0; + prevScrollCallback = o.prevScrollCallback; o.prevScrollCallback = nullptr; + if (wnd) { + TFWindowDetail::RegisterScrollContext(wnd, this); + } + } + + void cleanup() { + if (!wnd && !device) return; // already moved/clean + // don’t terminate GLFW here; only destroy this window + if (device) { + ReleaseImGui(*this); + + (void)device.waitIdle(); + for (auto fb : framebuffers) device.destroyFramebuffer(fb); + framebuffers.clear(); + for (auto view : imageViews) device.destroyImageView(view); + imageViews.clear(); + if (imguiPool) { device.destroyDescriptorPool(imguiPool); imguiPool=nullptr; } + if (renderPass) { device.destroyRenderPass(renderPass); renderPass=nullptr; } + if (fence) device.destroyFence(fence), fence=nullptr; + if (semDone) device.destroySemaphore(semDone), semDone=nullptr; + if (semImage) device.destroySemaphore(semImage), semImage=nullptr; + if (cmd) device.freeCommandBuffers(pool, cmd), cmd=nullptr; + if (pool) device.destroyCommandPool(pool), pool=nullptr; + if (swapchain) device.destroySwapchainKHR(swapchain), swapchain=nullptr; + } + if (surface) instance.destroySurfaceKHR(surface), surface=nullptr; + if (wnd) { + TFWindowDetail::UnregisterScrollContext(wnd); + glfwDestroyWindow(wnd); + wnd=nullptr; + } + // leave GLFW alive; app can call glfwTerminate() once at shutdown if it wants + device=nullptr; instance=nullptr; queue=nullptr; + scrollDeltaX = 0.0; + scrollDeltaY = 0.0; + prevScrollCallback = nullptr; + } +}; + +WindowContext createWindow(int width, int height, const char* title); +bool windowOpen(const WindowContext& ctx); +void drawBuffer(WindowContext& ctx, vk::Buffer src, uint32_t width, uint32_t height, vk::DeviceSize offset = 0); +void drawBuffer(WindowContext& ctx, const Buffer& b, uint32_t w, uint32_t h, size_t offset = 0); +void AttachWindowCallbacks(WindowContext& ctx); + +// Global helpers used by higher-level integrations / Python bindings +WindowContext* GetWindow(); +WindowContext& RequireWindow(); +ImGuiContext* GetImGuiContext(); +void EnsureImGuiFrame(WindowContext& ctx); + +void ShowWindow(int width, int height, const char* title); +void HideWindow(); +void RenderFrame(const Buffer* buffer, uint32_t width, uint32_t height, size_t offset = 0); +void RenderFrame(const Buffer* buffer = nullptr); +bool WindowShouldClose(); +std::pair GetMousePosition(); +std::pair GetWindowSize(); +bool IsMouseButtonPressed(int button); +bool IsKeyPressed(int key); diff --git a/TensorFrost/Backend/src/RenderDoc.cpp b/TensorFrost/Backend/src/RenderDoc.cpp new file mode 100644 index 00000000..0b5fbc7f --- /dev/null +++ b/TensorFrost/Backend/src/RenderDoc.cpp @@ -0,0 +1,157 @@ +#include "Backend/RenderDoc.h" + +#include + +#include +#include +#include +#include +#include + +#include "Backend/Vulkan.h" + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#include +#else +#include +#endif + +namespace TensorFrost { +namespace { +RENDERDOC_API_1_4_2* gRenderDocApi = nullptr; + +void LoadRenderDoc() +{ + static bool loggedUnavailable = false; + + if (gRenderDocApi) { + return; + } + +#if defined(_WIN32) + const char* moduleNames[] = {"renderdoc.dll", "renderdoccmd.dll"}; + + for (const char* moduleName : moduleNames) { + HMODULE mod = GetModuleHandleA(moduleName); + if (!mod) { + continue; + } + + auto getApi = reinterpret_cast(GetProcAddress(mod, "RENDERDOC_GetAPI")); + if (getApi && + getApi(eRENDERDOC_API_Version_1_4_2, reinterpret_cast(&gRenderDocApi)) == 1) { + std::cout << "RenderDoc API loaded from " << moduleName << std::endl; + loggedUnavailable = false; + break; + } + } +#else + if (void* mod = dlopen("librenderdoc.so", RTLD_NOW | RTLD_NOLOAD)) { + auto getApi = reinterpret_cast(dlsym(mod, "RENDERDOC_GetAPI")); + if (getApi && getApi(eRENDERDOC_API_Version_1_4_2, reinterpret_cast(&gRenderDocApi)) == 1) { + std::cout << "RenderDoc API loaded" << std::endl; + loggedUnavailable = false; + dlclose(mod); + return; + } + dlclose(mod); + } +#endif + + if (!gRenderDocApi && !loggedUnavailable) { + std::cout << "RenderDoc API not available" << std::endl; + loggedUnavailable = true; + } +} +} // namespace + +bool IsRenderDocAvailable() +{ + LoadRenderDoc(); + return gRenderDocApi != nullptr; +} + +void StartRenderDocCapture() +{ + if (!IsRenderDocAvailable()) { + std::cout << "RenderDoc not available; start capture skipped" << std::endl; + return; + } + + RENDERDOC_DevicePointer deviceHandle = nullptr; + try { + auto& ctx = getVulkanContext(); + VkInstance vkInstance = static_cast(ctx.instance); + deviceHandle = RENDERDOC_DEVICEPOINTER_FROM_VKINSTANCE(vkInstance); + } catch (const std::exception& e) { + std::cout << "RenderDoc capture start warning: failed to get Vulkan instance (" << e.what() << ")" << std::endl; + } + + gRenderDocApi->StartFrameCapture(deviceHandle, nullptr); + if (gRenderDocApi->IsFrameCapturing && gRenderDocApi->IsFrameCapturing() == 1) { + std::cout << "RenderDoc capture start requested" << std::endl; + } else { + std::cout << "RenderDoc capture start did not begin (IsFrameCapturing=0)" << std::endl; + } +} + +std::string EndRenderDocCapture(bool launchReplayUI) +{ + if (!IsRenderDocAvailable()) { + std::cout << "RenderDoc not available; end capture skipped" << std::endl; + return {}; + } + + RENDERDOC_DevicePointer deviceHandle = nullptr; + try { + auto& ctx = getVulkanContext(); + VkInstance vkInstance = static_cast(ctx.instance); + deviceHandle = RENDERDOC_DEVICEPOINTER_FROM_VKINSTANCE(vkInstance); + } catch (const std::exception& e) { + std::cout << "RenderDoc capture end warning: failed to get Vulkan instance (" << e.what() << ")" << std::endl; + } + + std::string capturePath; + const uint32_t result = gRenderDocApi->EndFrameCapture(deviceHandle, nullptr); + if (result == 1) { + std::cout << "RenderDoc capture end requested" << std::endl; + if (gRenderDocApi->GetNumCaptures && gRenderDocApi->GetCapture) { + const uint32_t numCaptures = gRenderDocApi->GetNumCaptures(); + if (numCaptures > 0) { + const uint32_t captureIndex = numCaptures - 1; + uint32_t pathLength = 0; + if (gRenderDocApi->GetCapture(captureIndex, nullptr, &pathLength, nullptr) == 1 && pathLength > 0) { + capturePath.resize(pathLength); + if (gRenderDocApi->GetCapture(captureIndex, capturePath.data(), &pathLength, nullptr) == 1) { + if (!capturePath.empty() && capturePath.back() == '\0') { + capturePath.pop_back(); + } + std::cout << "RenderDoc capture stored at: " << capturePath << std::endl; + if (launchReplayUI && !capturePath.empty()) { + if (gRenderDocApi->LaunchReplayUI) { + const std::string commandLine = "\"" + capturePath + "\""; + const uint32_t pid = gRenderDocApi->LaunchReplayUI(0, commandLine.c_str()); + if (pid == 0) { + std::cout << "RenderDoc replay UI launch failed" << std::endl; + } else { + std::cout << "RenderDoc replay UI launched (PID " << pid << ")" << std::endl; + } + } else { + std::cout << "RenderDoc replay UI launch unavailable (API function missing)" << std::endl; + } + } + } else { + capturePath.clear(); + } + } + } + } + } else { + std::cout << "RenderDoc capture end failed (" << result << ")" << std::endl; + } + + return capturePath; +} + +} // namespace TensorFrost diff --git a/TensorFrost/Backend/src/Vulkan.cpp b/TensorFrost/Backend/src/Vulkan.cpp new file mode 100644 index 00000000..11794e1a --- /dev/null +++ b/TensorFrost/Backend/src/Vulkan.cpp @@ -0,0 +1,636 @@ +#include "Backend/Vulkan.h" +VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +#include +#include +#include +#include +#include +#include + +namespace { +VulkanContext& getOrCreateGlobalContext() { + static VulkanContext ctx{}; + return ctx; +} +} + +VulkanContext& getVulkanContext() { + return getOrCreateGlobalContext(); +} + +VulkanContext::VulkanContext() { + if (!glfwInit()) throw std::runtime_error("GLFW init failed"); + if (!glfwVulkanSupported()) + throw std::runtime_error("GLFW: Vulkan loader not found. Install a Vulkan-capable GPU driver or the Vulkan SDK."); + + VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr); + + uint32_t extCount = 0; + const char** extNames = glfwGetRequiredInstanceExtensions(&extCount); + if (!extNames || extCount == 0) throw std::runtime_error("GLFW: Vulkan not supported"); + + vk::ApplicationInfo appInfo("TensorFrost", 1, nullptr, 0, VK_API_VERSION_1_2); + vk::InstanceCreateInfo instCreate({}, &appInfo, 0, nullptr, extCount, extNames); + instance = vk::createInstance(instCreate); + VULKAN_HPP_DEFAULT_DISPATCHER.init(instance); + + // pick device + queue families (compute + graphics) + auto devices = instance.enumeratePhysicalDevices(); + if (devices.empty()) throw std::runtime_error("No physical devices"); + + for (auto& pd : devices) { + auto q = pd.getQueueFamilyProperties(); + int compute = -1, graphics = -1; + for (uint32_t i = 0; i < q.size(); ++i) { + auto f = q[i].queueFlags; + if (compute < 0 && (f & vk::QueueFlagBits::eCompute)) compute = int(i); + if (graphics < 0 && (f & vk::QueueFlagBits::eGraphics)) graphics = int(i); + } + if (compute >= 0 && graphics >= 0) { + physicalDevice = pd; + queueFamilyIndex = uint32_t(compute); + graphicsFamilyIndex = uint32_t(graphics); + break; + } + } + if (!physicalDevice) throw std::runtime_error("No suitable device"); + + // device with both queues + swapchain + float prio = 1.0f; + std::vector queues; + queues.emplace_back(vk::DeviceQueueCreateInfo({}, queueFamilyIndex, 1, &prio)); + if (graphicsFamilyIndex != queueFamilyIndex) + queues.emplace_back(vk::DeviceQueueCreateInfo({}, graphicsFamilyIndex, 1, &prio)); + + // enable VK_KHR_swapchain (required for window/swapchain) + const char* devExts[] = { VK_KHR_SWAPCHAIN_EXTENSION_NAME }; + vk::DeviceCreateInfo devCreate({}, (uint32_t)queues.size(), queues.data(), + 0, nullptr, 1, devExts); + device = physicalDevice.createDevice(devCreate); + VULKAN_HPP_DEFAULT_DISPATCHER.init(device); + + computeQueue = device.getQueue(queueFamilyIndex, 0); + presentQueue = device.getQueue(graphicsFamilyIndex, 0); + + vk::CommandPoolCreateInfo poolInfo({}, queueFamilyIndex); + commandPool = device.createCommandPool(poolInfo); + + vk::DescriptorPoolSize sz(vk::DescriptorType::eStorageBuffer, 1024); + vk::DescriptorPoolCreateInfo dp(vk::DescriptorPoolCreateFlagBits::eFreeDescriptorSet, 256, 1, &sz); + descriptorPool = device.createDescriptorPool(dp); + dsCacheCapacity = 256; +} + +static void evictSome(VulkanContext& ctx, size_t n) { + if (ctx.dsCache.empty() || n == 0) return; + std::vector::iterator> items; + items.reserve(ctx.dsCache.size()); + for (auto it = ctx.dsCache.begin(); it != ctx.dsCache.end(); ++it) items.push_back(it); + std::stable_sort(items.begin(), items.end(), + [](auto a, auto b){ return a->second.lastUseTick < b->second.lastUseTick; }); // LRU first + n = std::min(n, items.size()); + for (size_t i = 0; i < n; ++i) { + auto it = items[i]; + if (it->second.set) ctx.device.freeDescriptorSets(ctx.descriptorPool, 1, &it->second.set); + ctx.dsCache.erase(it); + } +} + +static void evictToCapacity(VulkanContext& ctx) { + if (ctx.dsCache.size() > ctx.dsCacheCapacity) + evictSome(ctx, ctx.dsCache.size() - ctx.dsCacheCapacity); +} + +static void invalidateDescriptorCacheForBuffer(VulkanContext& ctx, VkBuffer buf) { + std::vector dead; + for (auto it = ctx.dsCache.begin(); it != ctx.dsCache.end(); ++it) + if (std::find(it->second.buffers.begin(), it->second.buffers.end(), buf) != it->second.buffers.end()) + dead.push_back(it); + for (auto it : dead) { + if (it->second.set) ctx.device.freeDescriptorSets(ctx.descriptorPool, 1, &it->second.set); + ctx.dsCache.erase(it); + } +} + +static void invalidateDescriptorCacheForLayout(VulkanContext& ctx, vk::DescriptorSetLayout layout) { + std::vector dead; + for (auto it = ctx.dsCache.begin(); it != ctx.dsCache.end(); ++it) + if (it->first.layout == layout) dead.push_back(it); + for (auto it : dead) { + if (it->second.set) ctx.device.freeDescriptorSets(ctx.descriptorPool, 1, &it->second.set); + ctx.dsCache.erase(it); + } +} + +// optional knobs +void setDescriptorCacheCapacity(VulkanContext& ctx, size_t cap) { + ctx.dsCacheCapacity = cap; + evictToCapacity(ctx); +} +void clearDescriptorCache(VulkanContext& ctx) { + evictSome(ctx, ctx.dsCache.size()); +} + +// --- cached descriptor set retrieval +static vk::DescriptorSet getOrCreateSet(VulkanContext& ctx, const ComputeProgram& prog, + const std::vector& ro, const std::vector& rw) { + if (ro.size()!=prog.numRO || rw.size()!=prog.numRW) throw std::runtime_error("buffer count != program layout"); + + DSKey key; key.layout = prog.descriptorLayout; + key.bufs.reserve(ro.size()+rw.size()); + key.sizes.reserve(ro.size()+rw.size()); + for (auto* b : ro) { key.bufs.push_back(b->buffer); key.sizes.push_back(b->size); } + for (auto* b : rw) { key.bufs.push_back(b->buffer); key.sizes.push_back(b->size); } + + if (auto it = ctx.dsCache.find(key); it != ctx.dsCache.end()) { + it->second.lastUseTick = ++ctx.dsUseTick; + it->second.useCount++; + return it->second.set; + } + + evictToCapacity(ctx); + + vk::DescriptorSet set{}; + for (int attempt = 0; attempt < 3; ++attempt) { + try { + vk::DescriptorSetAllocateInfo ai(ctx.descriptorPool, 1, &prog.descriptorLayout); + set = ctx.device.allocateDescriptorSets(ai)[0]; + break; + } catch (const vk::SystemError& e) { + auto r = static_cast(e.code().value()); + if (r == vk::Result::eErrorOutOfPoolMemory || r == vk::Result::eErrorFragmentedPool) { + evictSome(ctx, std::max(1, ctx.dsCache.size()/4)); + } else { + throw; + } + } + } + if (!set) throw std::runtime_error("Descriptor set allocation failed"); + + std::vector infos; infos.reserve(key.bufs.size()); + for (auto* b : ro) infos.emplace_back(b->buffer, 0, b->size); + for (auto* b : rw) infos.emplace_back(b->buffer, 0, b->size); + + std::vector writes; writes.reserve(infos.size()); + for (uint32_t i=0;i nodes; nodes.reserve(bufferCacheCount(ctx)); + + for (auto& kv : ctx.bufferCache) + for (auto& e : kv.second) + nodes.push_back(Node{ kv.first, e.buf.buffer, e.lastUse }); + + std::stable_sort(nodes.begin(), nodes.end(), + [](const Node& a, const Node& b){ return a.last < b.last; }); + + if (n > nodes.size()) n = nodes.size(); + + for (size_t i = 0; i < n; ++i) { + auto itMap = ctx.bufferCache.find(nodes[i].key); + if (itMap == ctx.bufferCache.end()) continue; + + auto& vec = itMap->second; + + auto it = std::find_if(vec.begin(), vec.end(), + [&](const CachedBuf& e){ + return e.buf.buffer == nodes[i].buf; + }); + if (it == vec.end()) continue; // already evicted by an earlier iteration + + Buffer b = it->buf; + vec.erase(it); + if (vec.empty()) ctx.bufferCache.erase(itMap); + + if (b.buffer) ctx.device.destroyBuffer(b.buffer); + if (b.memory) ctx.device.freeMemory(b.memory); + } +} + +static void evictBuffersToCapacity(VulkanContext& ctx) { + size_t cnt = bufferCacheCount(ctx); + if (cnt > ctx.bufferCacheCapacity) evictBuffers(ctx, cnt - ctx.bufferCacheCapacity); +} + +void clearBufferCache(VulkanContext& ctx) { evictBuffers(ctx, bufferCacheCount(ctx)); } + +void setBufferCacheCapacity(VulkanContext& ctx, size_t cap) { ctx.bufferCacheCapacity = cap; evictBuffersToCapacity(ctx); } + +static bool takeBufferFromCache(VulkanContext& ctx, size_t bytes, Buffer& out) { + auto it = ctx.bufferCache.find(bytes); + if (it == ctx.bufferCache.end() || it->second.empty()) return false; + auto e = std::move(it->second.back()); + it->second.pop_back(); + if (it->second.empty()) ctx.bufferCache.erase(it); + out = e.buf; // handles copied; cache keeps no reference + return true; +} + +// create a storage buffer +Buffer createBuffer(size_t count, size_t dtypeSize, bool readOnly) { + auto& ctx = getVulkanContext(); + Buffer buf{}; + buf.size = count * dtypeSize; + + if (takeBufferFromCache(ctx, buf.size, buf)) return buf; + + vk::BufferCreateInfo bci({}, buf.size, + vk::BufferUsageFlagBits::eStorageBuffer | + vk::BufferUsageFlagBits::eTransferSrc | + vk::BufferUsageFlagBits::eTransferDst); + buf.buffer = ctx.device.createBuffer(bci); + auto memReq = ctx.device.getBufferMemoryRequirements(buf.buffer); + + auto memProps = ctx.physicalDevice.getMemoryProperties(); + uint32_t memTypeIndex = UINT32_MAX; + // Prefer device-local memory for GPU performance + for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { + if ((memReq.memoryTypeBits & (1u< buf.size) throw std::out_of_range("write out of range"); + if (bytes == 0) return; + + // Create a temporary host-visible staging buffer for upload + vk::BufferCreateInfo bci({}, bytes, vk::BufferUsageFlagBits::eTransferSrc); + vk::Buffer staging = ctx.device.createBuffer(bci); + auto memReq = ctx.device.getBufferMemoryRequirements(staging); + + auto memProps = ctx.physicalDevice.getMemoryProperties(); + uint32_t memTypeIndex = UINT32_MAX; + // Prefer HOST_VISIBLE | HOST_COHERENT for simple map without flush; fallback to HOST_VISIBLE + for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { + auto f = memProps.memoryTypes[i].propertyFlags; + if ((memReq.memoryTypeBits & (1u< buf.size) throw std::out_of_range("read out of range"); + if (bytes == 0) return; + + // Create a temporary host-visible staging buffer for download + vk::BufferCreateInfo bci({}, bytes, vk::BufferUsageFlagBits::eTransferDst); + vk::Buffer staging = ctx.device.createBuffer(bci); + auto memReq = ctx.device.getBufferMemoryRequirements(staging); + + auto memProps = ctx.physicalDevice.getMemoryProperties(); + uint32_t memTypeIndex = UINT32_MAX; + // Prefer HOST_VISIBLE | HOST_COHERENT; fallback to HOST_VISIBLE + for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { + auto f = memProps.memoryTypes[i].propertyFlags; + if ((memReq.memoryTypeBits & (1u<second.set) device.freeDescriptorSets(descriptorPool, 1, &it->second.set); + } + dsCache.clear(); + clearBufferCache(*this); + device.destroyDescriptorPool(descriptorPool); + device.destroyCommandPool(commandPool); + device.destroy(); + instance.destroy(); +} + +std::vector compileSlangToSpirv(const char* moduleName, + const char* source, + const char* entry, + const char* profile /* e.g., "spirv_1_5" */) { + Slang::ComPtr global; + createGlobalSession(global.writeRef()); + + slang::TargetDesc tgt{}; + tgt.format = SLANG_SPIRV; + tgt.profile = global->findProfile(profile); + + slang::SessionDesc sd{}; + sd.targets = &tgt; sd.targetCount = 1; +#if defined(_RELWITHDEBINFO) + std::array optionEntries{}; + optionEntries[0].name = slang::CompilerOptionName::DebugInformation; + optionEntries[0].value.kind = slang::CompilerOptionValueKind::Int; + optionEntries[0].value.intValue0 = SLANG_DEBUG_INFO_LEVEL_STANDARD; + optionEntries[1].name = slang::CompilerOptionName::Optimization; + optionEntries[1].value.kind = slang::CompilerOptionValueKind::Int; + optionEntries[1].value.intValue0 = SLANG_OPTIMIZATION_LEVEL_NONE; + sd.compilerOptionEntries = optionEntries.data(); + sd.compilerOptionEntryCount = static_cast(optionEntries.size()); +#endif + + Slang::ComPtr session; + global->createSession(sd, session.writeRef()); + + Slang::ComPtr diag; + Slang::ComPtr mod; + mod = session->loadModuleFromSourceString(moduleName, moduleName, source, diag.writeRef()); + if (diag && diag->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)diag->getBufferPointer()); + if (!mod) throw std::runtime_error("slang: module load failed"); + + Slang::ComPtr ep; + mod->findEntryPointByName(entry, ep.writeRef()); + if (!ep) throw std::runtime_error("slang: entry not found"); + + slang::IComponentType* parts[] = { mod.get(), ep.get() }; + Slang::ComPtr composed, linked; + + { + Slang::ComPtr d; + SlangResult r = session->createCompositeComponentType(parts, 2, composed.writeRef(), d.writeRef()); + if (d && d->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)d->getBufferPointer()); + if (SLANG_FAILED(r)) throw std::runtime_error("slang: compose failed"); + } + { + Slang::ComPtr d; + SlangResult r = composed->link(linked.writeRef(), d.writeRef()); + if (d && d->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)d->getBufferPointer()); + if (SLANG_FAILED(r)) throw std::runtime_error("slang: link failed"); + } + + Slang::ComPtr spirv; + { + Slang::ComPtr d; + SlangResult r = linked->getEntryPointCode(0, 0, spirv.writeRef(), d.writeRef()); + if (d && d->getBufferSize()) std::fprintf(stderr, "%s\n", (const char*)d->getBufferPointer()); + if (SLANG_FAILED(r)) throw std::runtime_error("slang: getEntryPointCode failed"); + } + + size_t n = spirv->getBufferSize(); + auto* p = static_cast(spirv->getBufferPointer()); + std::vector out((n + 3) / 4); + std::memcpy(out.data(), p, n); + return out; +} + +ComputeBindings createBindings(VulkanContext& ctx, const ComputeProgram& prog, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers) { + if (readonlyBuffers.size() != prog.numRO || readwriteBuffers.size() != prog.numRW) + throw std::runtime_error("buffer count != program layout"); + + vk::DescriptorSetAllocateInfo ai(ctx.descriptorPool, 1, &prog.descriptorLayout); + ComputeBindings b{}; + b.set = ctx.device.allocateDescriptorSets(ai)[0]; + + std::vector infos; + infos.reserve(prog.numRO + prog.numRW); + for (auto* x : readonlyBuffers) infos.emplace_back(x->buffer, 0, x->size); + for (auto* x : readwriteBuffers) infos.emplace_back(x->buffer, 0, x->size); + + std::vector writes; + writes.reserve(infos.size()); + for (uint32_t i = 0; i < infos.size(); ++i) + writes.emplace_back(b.set, i, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &infos[i]); + + ctx.device.updateDescriptorSets(writes, {}); + return b; +} + +static ComputeProgram createComputeProgram(const std::vector& spirv, + uint32_t roCount, uint32_t rwCount, uint32_t pushConstantSize) { + auto& ctx = getVulkanContext(); + + ComputeProgram prog; + prog.numRO = roCount; prog.numRW = rwCount; + prog.pushConstantSize = pushConstantSize; + + if (prog.pushConstantSize) { + auto limits = ctx.physicalDevice.getProperties().limits; + if (prog.pushConstantSize > limits.maxPushConstantsSize) { + throw std::runtime_error("push constant block exceeds device limit"); + } + } + + vk::ShaderModuleCreateInfo smci({}, spirv.size() * sizeof(uint32_t), spirv.data()); + prog.shaderModule = ctx.device.createShaderModule(smci); + + std::vector bindings; + bindings.reserve(roCount + rwCount); + for (uint32_t b = 0; b < roCount + rwCount; ++b) + bindings.emplace_back(b, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute); + + vk::DescriptorSetLayoutCreateInfo dsInfo({}, bindings.size(), bindings.data()); + prog.descriptorLayout = ctx.device.createDescriptorSetLayout(dsInfo); + + vk::PushConstantRange pushRange(vk::ShaderStageFlagBits::eCompute, 0, prog.pushConstantSize); + auto pushPtr = prog.pushConstantSize ? &pushRange : nullptr; + uint32_t pushCount = prog.pushConstantSize ? 1u : 0u; + + vk::PipelineLayoutCreateInfo plInfo({}, 1, &prog.descriptorLayout, pushCount, pushPtr); + prog.pipelineLayout = ctx.device.createPipelineLayout(plInfo); + + vk::PipelineShaderStageCreateInfo stageInfo({}, vk::ShaderStageFlagBits::eCompute, prog.shaderModule, "main"); + vk::ComputePipelineCreateInfo cpInfo({}, stageInfo, prog.pipelineLayout); + prog.pipeline = ctx.device.createComputePipeline({}, cpInfo).value; + + return prog; +} + +ComputeProgram createComputeProgramFromSlang(const std::string& moduleName, + const std::string& source, const std::string& entry, + uint32_t roCount, uint32_t rwCount, uint32_t pushConstantSize) { + auto spirv = compileSlangToSpirv(moduleName.c_str(), source.c_str(), entry.c_str(), "spirv_1_5"); + return createComputeProgram(spirv, roCount, rwCount, pushConstantSize); +} + +void destroyComputeProgram(ComputeProgram& prog) { + auto& ctx = getVulkanContext(); + invalidateDescriptorCacheForLayout(ctx, prog.descriptorLayout); + ctx.device.destroyPipeline(prog.pipeline); + ctx.device.destroyPipelineLayout(prog.pipelineLayout); + ctx.device.destroyDescriptorSetLayout(prog.descriptorLayout); + ctx.device.destroyShaderModule(prog.shaderModule); + prog = {}; +} + +void runProgram(const ComputeProgram& prog, + const std::vector& readonlyBuffers, + const std::vector& readwriteBuffers, + uint32_t groupCount, + const void* pushConstants, + size_t pushConstantSize) { + auto& ctx = getVulkanContext(); + auto set = getOrCreateSet(ctx, prog, readonlyBuffers, readwriteBuffers); + + vk::CommandBufferAllocateInfo ai(ctx.commandPool, vk::CommandBufferLevel::ePrimary, 1); + auto cmd = ctx.device.allocateCommandBuffers(ai)[0]; + + cmd.begin(vk::CommandBufferBeginInfo{}); + cmd.bindPipeline(vk::PipelineBindPoint::eCompute, prog.pipeline); + cmd.bindDescriptorSets(vk::PipelineBindPoint::eCompute, prog.pipelineLayout, 0, set, {}); + + if (prog.pushConstantSize) { + if (!pushConstants) { + throw std::runtime_error("push constant payload missing"); + } + if (pushConstantSize != prog.pushConstantSize) { + throw std::runtime_error("push constant payload size mismatch"); + } + cmd.pushConstants(prog.pipelineLayout, vk::ShaderStageFlagBits::eCompute, 0, + prog.pushConstantSize, pushConstants); + } else if (pushConstantSize != 0) { + throw std::runtime_error("push constant payload provided but pipeline has none"); + } + + cmd.dispatch(groupCount, 1, 1); + cmd.end(); + + vk::Fence fence = ctx.device.createFence({}); + ctx.computeQueue.submit(vk::SubmitInfo(0, nullptr, 0, 1, &cmd), fence); + [[maybe_unused]] auto rWait = ctx.device.waitForFences(fence, VK_TRUE, UINT64_MAX); + ctx.device.destroyFence(fence); + ctx.device.freeCommandBuffers(ctx.commandPool, cmd); +} + diff --git a/TensorFrost/Backend/src/Window.cpp b/TensorFrost/Backend/src/Window.cpp new file mode 100644 index 00000000..0cf981ae --- /dev/null +++ b/TensorFrost/Backend/src/Window.cpp @@ -0,0 +1,597 @@ +#include "Backend/Vulkan.h" +#include "Backend/Window.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +static std::unordered_map gScrollContexts; +static std::mutex gScrollMutex; + +namespace { +std::unique_ptr gWindow; +std::mutex gWindowMutex; + +void CheckVkResult(VkResult err) { + if (err == VK_SUCCESS) return; + throw std::runtime_error("ImGui Vulkan backend error: " + std::to_string(err)); +} + +void ScrollCallback(GLFWwindow* wnd, double xoffset, double yoffset) { + WindowContext* ctx = nullptr; + GLFWscrollfun prev = nullptr; + { + std::scoped_lock lock(gScrollMutex); + auto it = gScrollContexts.find(wnd); + if (it != gScrollContexts.end()) { + ctx = it->second; + prev = ctx ? ctx->prevScrollCallback : nullptr; + } + } + + if (ctx) { + ctx->scrollDeltaX += xoffset; + ctx->scrollDeltaY += yoffset; + } + + if (prev && prev != ScrollCallback) { + prev(wnd, xoffset, yoffset); + } +} + +void EnsureFramebuffers(WindowContext& ctx) { + if (ctx.framebuffers.size() == ctx.images.size() && !ctx.framebuffers.empty()) return; + + if (!ctx.renderPass) { + vk::AttachmentDescription colorAttachment{}; + colorAttachment.format = ctx.format; + colorAttachment.samples = vk::SampleCountFlagBits::e1; + colorAttachment.loadOp = vk::AttachmentLoadOp::eLoad; + colorAttachment.storeOp = vk::AttachmentStoreOp::eStore; + colorAttachment.stencilLoadOp = vk::AttachmentLoadOp::eDontCare; + colorAttachment.stencilStoreOp = vk::AttachmentStoreOp::eDontCare; + colorAttachment.initialLayout = vk::ImageLayout::eColorAttachmentOptimal; + colorAttachment.finalLayout = vk::ImageLayout::ePresentSrcKHR; + + vk::AttachmentReference colorAttachmentRef{0, vk::ImageLayout::eColorAttachmentOptimal}; + + vk::SubpassDescription subpass{}; + subpass.pipelineBindPoint = vk::PipelineBindPoint::eGraphics; + subpass.colorAttachmentCount = 1; + subpass.pColorAttachments = &colorAttachmentRef; + + vk::SubpassDependency dependency{}; + dependency.srcSubpass = VK_SUBPASS_EXTERNAL; + dependency.dstSubpass = 0; + dependency.srcStageMask = vk::PipelineStageFlagBits::eColorAttachmentOutput; + dependency.dstStageMask = vk::PipelineStageFlagBits::eColorAttachmentOutput; + dependency.srcAccessMask = {}; + dependency.dstAccessMask = vk::AccessFlagBits::eColorAttachmentWrite; + + vk::RenderPassCreateInfo rpci{}; + rpci.attachmentCount = 1; + rpci.pAttachments = &colorAttachment; + rpci.subpassCount = 1; + rpci.pSubpasses = &subpass; + rpci.dependencyCount = 1; + rpci.pDependencies = &dependency; + + ctx.renderPass = ctx.device.createRenderPass(rpci); + } + + if (ctx.imageViews.size() != ctx.images.size()) { + for (auto view : ctx.imageViews) ctx.device.destroyImageView(view); + ctx.imageViews.clear(); + ctx.imageViews.reserve(ctx.images.size()); + for (auto image : ctx.images) { + vk::ImageViewCreateInfo viewInfo{}; + viewInfo.image = image; + viewInfo.viewType = vk::ImageViewType::e2D; + viewInfo.format = ctx.format; + viewInfo.components = vk::ComponentMapping(); + viewInfo.subresourceRange = {vk::ImageAspectFlagBits::eColor, 0, 1, 0, 1}; + ctx.imageViews.push_back(ctx.device.createImageView(viewInfo)); + } + } + + for (auto fb : ctx.framebuffers) ctx.device.destroyFramebuffer(fb); + ctx.framebuffers.clear(); + ctx.framebuffers.reserve(ctx.imageViews.size()); + for (auto view : ctx.imageViews) { + vk::FramebufferCreateInfo fbci{}; + fbci.renderPass = ctx.renderPass; + fbci.attachmentCount = 1; + fbci.pAttachments = &view; + fbci.width = ctx.extent.width; + fbci.height = ctx.extent.height; + fbci.layers = 1; + ctx.framebuffers.push_back(ctx.device.createFramebuffer(fbci)); + } +} + +void EnsureImGui(WindowContext& ctx, uint32_t imageCount) { + if (ctx.imguiContext) return; + + EnsureFramebuffers(ctx); + + std::array poolSizes = { + vk::DescriptorPoolSize{vk::DescriptorType::eSampler, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eCombinedImageSampler, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eSampledImage, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eStorageImage, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eUniformTexelBuffer, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eStorageTexelBuffer, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eUniformBuffer, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eStorageBuffer, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eUniformBufferDynamic, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eStorageBufferDynamic, 1000}, + vk::DescriptorPoolSize{vk::DescriptorType::eInputAttachment, 1000} + }; + + vk::DescriptorPoolCreateInfo poolInfo(vk::DescriptorPoolCreateFlagBits::eFreeDescriptorSet, + 1000 * static_cast(poolSizes.size()), + static_cast(poolSizes.size()), + poolSizes.data()); + ctx.imguiPool = ctx.device.createDescriptorPool(poolInfo); + + IMGUI_CHECKVERSION(); + ctx.imguiContext = ImGui::CreateContext(); + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui::StyleColorsDark(); + + ImGui_ImplGlfw_InitForVulkan(ctx.wnd, true); + + ImGui_ImplVulkan_InitInfo initInfo{}; + initInfo.Instance = ctx.instance; + initInfo.PhysicalDevice = ctx.phys; + initInfo.Device = ctx.device; + initInfo.QueueFamily = ctx.presentFam; + initInfo.Queue = ctx.queue; + initInfo.Subpass = 0; + initInfo.MinImageCount = imageCount; + initInfo.ImageCount = imageCount; + initInfo.MSAASamples = VK_SAMPLE_COUNT_1_BIT; + initInfo.Allocator = nullptr; + initInfo.PipelineCache = VK_NULL_HANDLE; + initInfo.DescriptorPool = ctx.imguiPool; + initInfo.CheckVkResultFn = CheckVkResult; + initInfo.RenderPass = static_cast(ctx.renderPass); + + if (!ImGui_ImplVulkan_Init(&initInfo)) { + throw std::runtime_error("ImGui_ImplVulkan_Init failed"); + } + + if (!ImGui_ImplVulkan_CreateFontsTexture()) { + throw std::runtime_error("ImGui_ImplVulkan_CreateFontsTexture failed"); + } +} + +void StartImGuiFrame(WindowContext& ctx) { + if (!ctx.imguiContext || ctx.imguiFrameActive) return; + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui_ImplVulkan_NewFrame(); + ImGui_ImplGlfw_NewFrame(); + ImGui::NewFrame(); + ctx.imguiFrameActive = true; +} + +void DestroySwapchainViews(WindowContext& ctx) { + for (auto fb : ctx.framebuffers) ctx.device.destroyFramebuffer(fb); + ctx.framebuffers.clear(); + for (auto view : ctx.imageViews) ctx.device.destroyImageView(view); + ctx.imageViews.clear(); +} + +vk::SurfaceFormatKHR SelectSurfaceFormat(const WindowContext& ctx, + const std::vector& formats) { + if (formats.empty()) { + return {vk::Format::eB8G8R8A8Unorm, vk::ColorSpaceKHR::eSrgbNonlinear}; + } + for (const auto& fmt : formats) { + if (fmt.format == ctx.format) { + return fmt; + } + } + return formats.front(); +} + +void RecreateSwapchain(WindowContext& ctx, vk::Extent2D desiredExtent) { + if (!ctx.wnd) return; + + ctx.device.waitIdle(); + + DestroySwapchainViews(ctx); + + auto caps = ctx.phys.getSurfaceCapabilitiesKHR(ctx.surface); + auto fmts = ctx.phys.getSurfaceFormatsKHR(ctx.surface); + if (!(caps.supportedUsageFlags & vk::ImageUsageFlagBits::eTransferDst)) { + throw std::runtime_error("swapchain missing TRANSFER_DST"); + } + + vk::SurfaceFormatKHR surfaceFormat = SelectSurfaceFormat(ctx, fmts); + + vk::Extent2D extent{}; + if (caps.currentExtent.width != UINT32_MAX) { + extent = caps.currentExtent; + } else { + extent.width = static_cast(std::clamp(static_cast(desiredExtent.width), + static_cast(caps.minImageExtent.width), + static_cast(caps.maxImageExtent.width))); + extent.height = static_cast(std::clamp(static_cast(desiredExtent.height), + static_cast(caps.minImageExtent.height), + static_cast(caps.maxImageExtent.height))); + } + + if (extent.width == 0 || extent.height == 0) { + ctx.extent = extent; + return; + } + + uint32_t imageCount = std::max(caps.minImageCount, 2u); + if (caps.maxImageCount) imageCount = std::min(imageCount, caps.maxImageCount); + + vk::SwapchainCreateInfoKHR createInfo{}; + createInfo.surface = ctx.surface; + createInfo.minImageCount = imageCount; + createInfo.imageFormat = surfaceFormat.format; + createInfo.imageColorSpace = surfaceFormat.colorSpace; + createInfo.imageExtent = extent; + createInfo.imageArrayLayers = 1; + createInfo.imageUsage = vk::ImageUsageFlagBits::eTransferDst; + createInfo.imageSharingMode = vk::SharingMode::eExclusive; + createInfo.preTransform = caps.currentTransform; + createInfo.compositeAlpha = vk::CompositeAlphaFlagBitsKHR::eOpaque; + createInfo.presentMode = vk::PresentModeKHR::eFifo; + createInfo.clipped = VK_TRUE; + createInfo.oldSwapchain = ctx.swapchain; + + vk::SwapchainKHR newSwapchain = ctx.device.createSwapchainKHR(createInfo); + if (ctx.swapchain) { + ctx.device.destroySwapchainKHR(ctx.swapchain); + } + ctx.swapchain = newSwapchain; + ctx.images = ctx.device.getSwapchainImagesKHR(ctx.swapchain); + ctx.extent = extent; + ctx.format = surfaceFormat.format; + + EnsureFramebuffers(ctx); + + if (ctx.imguiContext) { + ImGui_ImplVulkan_SetMinImageCount(imageCount); + } +} +} // namespace + +namespace TFWindowDetail { + +void RegisterScrollContext(GLFWwindow* wnd, WindowContext* ctx) { + if (!wnd) return; + std::scoped_lock lock(gScrollMutex); + gScrollContexts[wnd] = ctx; +} + +void UnregisterScrollContext(GLFWwindow* wnd) { + if (!wnd) return; + std::scoped_lock lock(gScrollMutex); + gScrollContexts.erase(wnd); +} + +} // namespace TFWindowDetail + +void AttachWindowCallbacks(WindowContext& ctx) { + if (!ctx.wnd) return; + TFWindowDetail::RegisterScrollContext(ctx.wnd, &ctx); + GLFWscrollfun prev = glfwSetScrollCallback(ctx.wnd, ScrollCallback); + if (prev && prev != ScrollCallback) { + ctx.prevScrollCallback = prev; + } +} + +void ReleaseImGui(WindowContext& ctx) { + if (!ctx.imguiContext) return; + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui_ImplVulkan_Shutdown(); + ImGui_ImplGlfw_Shutdown(); + ImGui::DestroyContext(ctx.imguiContext); + ctx.imguiContext = nullptr; + ctx.imguiFrameActive = false; +} + +WindowContext createWindow(int width, int height, const char* title) { + auto& vctx = getVulkanContext(); + if (!glfwInit()) throw std::runtime_error("glfwInit"); + glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API); + + WindowContext ctx{}; + ctx.wnd = glfwCreateWindow(width, height, title, nullptr, nullptr); + if (!ctx.wnd) throw std::runtime_error("glfwCreateWindow"); + + ctx.instance = vctx.instance; + ctx.phys = vctx.physicalDevice; + ctx.device = vctx.device; + + // create surface on the shared instance + VkSurfaceKHR raw{}; + if (glfwCreateWindowSurface(ctx.instance, ctx.wnd, nullptr, &raw) != VK_SUCCESS) + throw std::runtime_error("glfw surface"); + ctx.surface = raw; + + // pick a present-capable family; prefer the graphics family we provisioned + if (ctx.phys.getSurfaceSupportKHR(vctx.graphicsFamilyIndex, ctx.surface)) { + ctx.presentFam = vctx.graphicsFamilyIndex; + ctx.queue = vctx.presentQueue; + } else if (ctx.phys.getSurfaceSupportKHR(vctx.queueFamilyIndex, ctx.surface)) { + ctx.presentFam = vctx.queueFamilyIndex; + ctx.queue = vctx.computeQueue; // uncommon but legal if it supports present + } else { + throw std::runtime_error("device queues do not support presenting to this surface"); + } + + auto caps = ctx.phys.getSurfaceCapabilitiesKHR(ctx.surface); + auto fmts = ctx.phys.getSurfaceFormatsKHR(ctx.surface); + vk::SurfaceFormatKHR sf = fmts.empty() + ? vk::SurfaceFormatKHR{vk::Format::eB8G8R8A8Unorm, vk::ColorSpaceKHR::eSrgbNonlinear} + : fmts[0]; + if (!(caps.supportedUsageFlags & vk::ImageUsageFlagBits::eTransferDst)) + throw std::runtime_error("swapchain missing TRANSFER_DST"); + + int fbw, fbh; glfwGetFramebufferSize(ctx.wnd, &fbw, &fbh); + ctx.extent = vk::Extent2D{ + uint32_t(std::clamp(fbw, int(caps.minImageExtent.width), int(caps.maxImageExtent.width))), + uint32_t(std::clamp(fbh, int(caps.minImageExtent.height), int(caps.maxImageExtent.height))) + }; + ctx.format = sf.format; + + uint32_t imageCount = std::max(caps.minImageCount, 2u); + if (caps.maxImageCount) imageCount = std::min(imageCount, caps.maxImageCount); + + ctx.swapchain = ctx.device.createSwapchainKHR({ + {}, ctx.surface, imageCount, sf.format, sf.colorSpace, ctx.extent, 1, + vk::ImageUsageFlagBits::eTransferDst, + vk::SharingMode::eExclusive, 0, nullptr, + caps.currentTransform, vk::CompositeAlphaFlagBitsKHR::eOpaque, + vk::PresentModeKHR::eFifo, VK_TRUE, {} + }); + ctx.images = ctx.device.getSwapchainImagesKHR(ctx.swapchain); + + ctx.pool = ctx.device.createCommandPool({vk::CommandPoolCreateFlagBits::eResetCommandBuffer, ctx.presentFam}); + ctx.cmd = ctx.device.allocateCommandBuffers({ctx.pool, vk::CommandBufferLevel::ePrimary, 1})[0]; + ctx.semImage = ctx.device.createSemaphore({}); + ctx.semDone = ctx.device.createSemaphore({}); + ctx.fence = ctx.device.createFence({}); + + EnsureImGui(ctx, imageCount); + StartImGuiFrame(ctx); + + return ctx; +} + +bool windowOpen(const WindowContext &ctx) { + return ctx.wnd && !glfwWindowShouldClose(ctx.wnd); +} + +void drawBuffer(WindowContext &ctx, vk::Buffer src, uint32_t width, uint32_t height, vk::DeviceSize offset) { + if (!ctx.wnd) return; + + glfwPollEvents(); + + uint32_t idx = 0; + while (true) { + int fbw = 0; + int fbh = 0; + glfwGetFramebufferSize(ctx.wnd, &fbw, &fbh); + if (fbw <= 0 || fbh <= 0) { + if (ctx.imguiContext && ctx.imguiFrameActive) { + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui::Render(); + ctx.imguiFrameActive = false; + } + return; + } + + vk::Extent2D desiredExtent{static_cast(fbw), static_cast(fbh)}; + if (desiredExtent.width != ctx.extent.width || desiredExtent.height != ctx.extent.height) { + RecreateSwapchain(ctx, desiredExtent); + continue; + } + + auto acq = ctx.device.acquireNextImageKHR(ctx.swapchain, UINT64_MAX, ctx.semImage, {}); + if (acq.result == vk::Result::eErrorOutOfDateKHR || acq.result == vk::Result::eSuboptimalKHR) { + RecreateSwapchain(ctx, desiredExtent); + continue; + } + + idx = acq.value; + break; + } + + ctx.cmd.reset({}); + ctx.cmd.begin({vk::CommandBufferUsageFlagBits::eOneTimeSubmit}); + + vk::ImageSubresourceRange range{vk::ImageAspectFlagBits::eColor, 0, 1, 0, 1}; + + vk::ImageMemoryBarrier toTransfer({}, vk::AccessFlagBits::eTransferWrite, + vk::ImageLayout::eUndefined, vk::ImageLayout::eTransferDstOptimal, + VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, + ctx.images[idx], range); + ctx.cmd.pipelineBarrier(vk::PipelineStageFlagBits::eTopOfPipe, + vk::PipelineStageFlagBits::eTransfer, + {}, nullptr, nullptr, toTransfer); + + uint32_t copyWidth = std::min(width, ctx.extent.width); + uint32_t copyHeight = std::min(height, ctx.extent.height); + + bool performedTransfer = false; + + if (!src || copyWidth != ctx.extent.width || copyHeight != ctx.extent.height) { + vk::ClearColorValue clearColor(std::array{0.f, 0.f, 0.f, 1.f}); + std::array ranges{range}; + ctx.cmd.clearColorImage(ctx.images[idx], vk::ImageLayout::eTransferDstOptimal, clearColor, ranges); + performedTransfer = true; + } + + if (src && copyWidth > 0 && copyHeight > 0) { + vk::BufferImageCopy copy{}; + copy.bufferOffset = offset; + copy.imageSubresource = {vk::ImageAspectFlagBits::eColor, 0, 0, 1}; + copy.imageExtent = vk::Extent3D{copyWidth, copyHeight, 1}; + ctx.cmd.copyBufferToImage(src, ctx.images[idx], vk::ImageLayout::eTransferDstOptimal, 1, ©); + performedTransfer = true; + } + + vk::ImageMemoryBarrier toColor(performedTransfer ? vk::AccessFlagBits::eTransferWrite : vk::AccessFlags{}, + vk::AccessFlagBits::eColorAttachmentWrite, + performedTransfer ? vk::ImageLayout::eTransferDstOptimal : vk::ImageLayout::eUndefined, + vk::ImageLayout::eColorAttachmentOptimal, + VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, + ctx.images[idx], range); + ctx.cmd.pipelineBarrier(performedTransfer ? vk::PipelineStageFlagBits::eTransfer : vk::PipelineStageFlagBits::eTopOfPipe, + vk::PipelineStageFlagBits::eColorAttachmentOutput, + {}, nullptr, nullptr, toColor); + + EnsureFramebuffers(ctx); + + vk::RenderPassBeginInfo rpBegin{}; + rpBegin.renderPass = ctx.renderPass; + rpBegin.framebuffer = ctx.framebuffers[idx]; + rpBegin.renderArea.offset = vk::Offset2D{0, 0}; + rpBegin.renderArea.extent = ctx.extent; + vk::ClearValue clearValue{}; + rpBegin.clearValueCount = 1; + rpBegin.pClearValues = &clearValue; + + ctx.cmd.beginRenderPass(rpBegin, vk::SubpassContents::eInline); + + if (ctx.imguiContext && ctx.imguiFrameActive) { + ImGui::SetCurrentContext(ctx.imguiContext); + ImGui::Render(); + ImDrawData* drawData = ImGui::GetDrawData(); + if (drawData && drawData->CmdListsCount > 0) { + ImGui_ImplVulkan_RenderDrawData(drawData, static_cast(ctx.cmd)); + } + ctx.imguiFrameActive = false; + } + + ctx.cmd.endRenderPass(); + + vk::ImageMemoryBarrier toPresent(vk::AccessFlagBits::eColorAttachmentWrite, {}, + vk::ImageLayout::eColorAttachmentOptimal, vk::ImageLayout::ePresentSrcKHR, + VK_QUEUE_FAMILY_IGNORED, VK_QUEUE_FAMILY_IGNORED, + ctx.images[idx], range); + ctx.cmd.pipelineBarrier(vk::PipelineStageFlagBits::eColorAttachmentOutput, + vk::PipelineStageFlagBits::eBottomOfPipe, + {}, nullptr, nullptr, toPresent); + + ctx.cmd.end(); + + vk::PipelineStageFlags waitStage = vk::PipelineStageFlagBits::eColorAttachmentOutput; + + (void)ctx.device.resetFences(ctx.fence); + ctx.queue.submit({vk::SubmitInfo(1, &ctx.semImage, &waitStage, 1, &ctx.cmd, 1, &ctx.semDone)}, ctx.fence); + (void)ctx.device.waitForFences(ctx.fence, VK_TRUE, UINT64_MAX); + + try { + (void)ctx.queue.presentKHR({1, &ctx.semDone, 1, &ctx.swapchain, &idx}); + } catch (const vk::OutOfDateKHRError&) { + // ignore + } + + StartImGuiFrame(ctx); +} + +void drawBuffer(WindowContext &ctx, const Buffer &b, uint32_t w, uint32_t h, size_t offset) { + // optional sanity check: + if (offset + size_t(w)*size_t(h)*4 > b.size) throw std::out_of_range("buffer too small"); + drawBuffer(ctx, b.buffer, w, h, offset); +} + +WindowContext* GetWindow() { + std::scoped_lock lock(gWindowMutex); + return gWindow.get(); +} + +WindowContext& RequireWindow() { + auto* wnd = GetWindow(); + if (!wnd) throw std::runtime_error("Window not created. Call ShowWindow() first."); + return *wnd; +} + +ImGuiContext* GetImGuiContext() { + auto* wnd = GetWindow(); + return wnd ? wnd->imguiContext : nullptr; +} + +void EnsureImGuiFrame(WindowContext& ctx) { + StartImGuiFrame(ctx); +} + +void ShowWindow(int width, int height, const char* title) { + std::scoped_lock lock(gWindowMutex); + if (gWindow && windowOpen(*gWindow)) return; + gWindow = std::make_unique(createWindow(width, height, title)); + AttachWindowCallbacks(*gWindow); +} + +void HideWindow() { + std::scoped_lock lock(gWindowMutex); + gWindow.reset(); +} + +void RenderFrame(const Buffer* buffer, uint32_t width, uint32_t height, size_t offset) { + auto& ctx = RequireWindow(); + EnsureImGuiFrame(ctx); + + uint32_t w = width ? width : ctx.extent.width; + uint32_t h = height ? height : ctx.extent.height; + + if (buffer) { + drawBuffer(ctx, buffer->buffer, w, h, offset); + } else { + drawBuffer(ctx, vk::Buffer{}, w, h, offset); + } +} + +void RenderFrame(const Buffer* buffer) { + auto& ctx = RequireWindow(); + RenderFrame(buffer, ctx.extent.width, ctx.extent.height, 0); +} + +bool WindowShouldClose() { + std::scoped_lock lock(gWindowMutex); + if (!gWindow) return true; + return !windowOpen(*gWindow); +} + +std::pair GetMousePosition() { + std::scoped_lock lock(gWindowMutex); + if (!gWindow || !gWindow->wnd) return {0.0, 0.0}; + double x{}, y{}; + glfwGetCursorPos(gWindow->wnd, &x, &y); + return {x, y}; +} + +std::pair GetWindowSize() { + std::scoped_lock lock(gWindowMutex); + if (!gWindow || !gWindow->wnd) return {0, 0}; + int w{}, h{}; + glfwGetWindowSize(gWindow->wnd, &w, &h); + return {w, h}; +} + +bool IsMouseButtonPressed(int button) { + std::scoped_lock lock(gWindowMutex); + if (!gWindow || !gWindow->wnd) return false; + return glfwGetMouseButton(gWindow->wnd, button) == GLFW_PRESS; +} + +bool IsKeyPressed(int key) { + std::scoped_lock lock(gWindowMutex); + if (!gWindow || !gWindow->wnd) return false; + return glfwGetKey(gWindow->wnd, key) == GLFW_PRESS; +} diff --git a/TensorFrost/CMakeLists.txt b/TensorFrost/CMakeLists.txt index db1ea482..35024978 100644 --- a/TensorFrost/CMakeLists.txt +++ b/TensorFrost/CMakeLists.txt @@ -1,8 +1,39 @@ -file(GLOB_RECURSE TENSORFROST_SOURCE_LIST CONFIGURE_DEPENDS *.cpp) -file(GLOB_RECURSE TENSORFROST_HEADER_LIST CONFIGURE_DEPENDS *.h *.hpp) +set(TF_INC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) +set(TF_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) -pybind11_add_module(TensorFrost ${TENSORFROST_SOURCE_LIST} ${TENSORFROST_HEADER_LIST}) +find_library(SLANG_LIB_RELEASE NAMES slang PATHS $ENV{VULKAN_SDK}/Lib) +find_library(SLANG_LIB_DEBUG NAMES slangd slang PATHS $ENV{VULKAN_SDK}/Lib) +if (NOT SLANG_LIB_RELEASE) + message(FATAL_ERROR "slang.lib not found in %VULKAN_SDK%/Lib") +endif() + +add_subdirectory(Compiler) +add_subdirectory(Backend) + +file(GLOB_RECURSE TENSORFROST_SOURCE_LIST CONFIGURE_DEPENDS + ${TF_SRC_DIR}/*.cpp) + +file(GLOB_RECURSE TENSORFROST_HEADER_LIST CONFIGURE_DEPENDS + ${TF_INC_DIR}/*.h ${TF_INC_DIR}/*.hpp) + +set(PYBIND_MODULE ${CMAKE_CURRENT_SOURCE_DIR}/PybindModule.cpp) + +# ---- Build pybind11 module ---- +pybind11_add_module(TensorFrost + ${PYBIND_MODULE} + ${TENSORFROST_SOURCE_LIST} + ${TENSORFROST_HEADER_LIST} +) + +# ---- Include directories ---- +target_include_directories(TensorFrost + PRIVATE + ${TF_INC_DIR} + ${Python3_INCLUDE_DIRS} +) + +# ---- Fix for macOS RPATH issues ---- if(APPLE) set(CMAKE_MACOSX_RPATH ON) set_target_properties(TensorFrost PROPERTIES @@ -11,49 +42,60 @@ if(APPLE) ) endif() -# Add GLFW -target_link_libraries(TensorFrost PRIVATE glfw) - -glad_add_library(glad_gl_core_46 SHARED API gl:core=4.6) -target_link_libraries(TensorFrost PRIVATE glad_gl_core_46) - -glad_add_library(glad_vulkan_12 REPRODUCIBLE LOADER API vulkan=1.2) -target_link_libraries(TensorFrost PRIVATE glad_vulkan_12) +target_link_libraries(TensorFrost PRIVATE + TensorFrostCompiler + TensorFrostBackend) -target_include_directories(TensorFrost PRIVATE ${Python3_INCLUDE_DIRS}) - -target_include_directories(TensorFrost PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - -#add imgui headers -target_include_directories(TensorFrost PRIVATE ${CMAKE_SOURCE_DIR}/external/imgui) -target_include_directories(TensorFrost PRIVATE ${CMAKE_SOURCE_DIR}/external/imgui/backends) - -#add imgui sources -file(GLOB IMGUI_SOURCE_LIST ${CMAKE_SOURCE_DIR}/external/imgui/*.cpp) -file(GLOB IMGUI_BACKEND_SOURCE_LIST ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_glfw.cpp ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_opengl3.cpp) +if (MSVC) + target_compile_options(TensorFrost PRIVATE /wd4804 /wd4805 /wd4018) +endif() +# ---- ImGui ---- +target_include_directories(TensorFrost PRIVATE + ${CMAKE_SOURCE_DIR}/external/imgui + ${CMAKE_SOURCE_DIR}/external/imgui/backends +) +file(GLOB IMGUI_SOURCE_LIST + ${CMAKE_SOURCE_DIR}/external/imgui/*.cpp) +file(GLOB IMGUI_BACKEND_SOURCE_LIST + ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_glfw.cpp + ${CMAKE_SOURCE_DIR}/external/imgui/backends/imgui_impl_vulkan.cpp) target_sources(TensorFrost PRIVATE ${IMGUI_SOURCE_LIST} ${IMGUI_BACKEND_SOURCE_LIST}) -#add renderdoc headers +# ---- RenderDoc headers ---- target_include_directories(TensorFrost PRIVATE ${CMAKE_SOURCE_DIR}/external/renderdoc) +# ---- convenience / IDE niceties ---- +source_group(TREE ${TF_SRC_DIR} PREFIX "Source Files" + FILES ${TENSORFROST_SOURCE_LIST}) -add_custom_target(install_python_package ALL - DEPENDS TensorFrost -) +source_group(TREE ${CMAKE_CURRENT_SOURCE_DIR} PREFIX "Source Files" + FILES ${PYBIND_MODULE}) -set(DEBUG_PYTHON_SCRIPT "${CMAKE_SOURCE_DIR}/examples/debug.py") +source_group(TREE ${TF_INC_DIR} PREFIX "Header Files" + FILES ${TENSORFROST_HEADER_LIST}) -set_target_properties(TensorFrost PROPERTIES - VS_DEBUGGER_COMMAND "${Python3_EXECUTABLE}" - VS_DEBUGGER_COMMAND_ARGUMENTS "${DEBUG_PYTHON_SCRIPT}" - VS_DEBUGGER_WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" -) +# ---- VS/debug helpers, profiling flags, etc. (unchanged) ---- +add_custom_target(install_python_package ALL DEPENDS TensorFrost) -# Set /PROFILE for RELWITHDEBINFO +set(DEBUG_PYTHON_SCRIPT "${CMAKE_SOURCE_DIR}/examples/debug.py") set_target_properties(TensorFrost PROPERTIES - LINK_FLAGS_RELWITHDEBINFO "/PROFILE" + VS_DEBUGGER_COMMAND "${Python3_EXECUTABLE}" + VS_DEBUGGER_COMMAND_ARGUMENTS "${DEBUG_PYTHON_SCRIPT}" + VS_DEBUGGER_WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + LINK_FLAGS_RELWITHDEBINFO "/PROFILE" ) -source_group(TREE ${CMAKE_CURRENT_SOURCE_DIR} PREFIX "Source Files" FILES ${TENSORFROST_SOURCE_LIST} ${TENSORFROST_HEADER_LIST}) - +set(VKSDK_BIN "$ENV{VULKAN_SDK}/Bin") + +add_custom_command(TARGET TensorFrost POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory "$" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/slang$<$:d>.dll" "$/slang$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/slang-glslang$<$:d>.dll" "$/slang-glslang$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/slang-glsl-module$<$:d>.dll" "$/slang-glsl-module$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/glslang$<$:d>.dll" "$/glslang$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/glslang-default-resource-limits$<$:d>.dll" "$/glslang-default-resource-limits$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/SPIRV-Tools-shared$<$:d>.dll" "$/SPIRV-Tools-shared$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/SPIRV$<$:d>.dll" "$/SPIRV$<$:d>.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${VKSDK_BIN}/SPVRemapper$<$:d>.dll" "$/SPVRemapper$<$:d>.dll" +) \ No newline at end of file diff --git a/TensorFrost/Compiler/CMakeLists.txt b/TensorFrost/Compiler/CMakeLists.txt new file mode 100644 index 00000000..5dccaf30 --- /dev/null +++ b/TensorFrost/Compiler/CMakeLists.txt @@ -0,0 +1,35 @@ +set(TF_COMPILER_INC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) +set(TF_COMPILER_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) + +file(GLOB_RECURSE TENSORFROST_COMPILER_SOURCE_LIST CONFIGURE_DEPENDS + ${TF_COMPILER_SRC_DIR}/*.cpp) + +file(GLOB_RECURSE TENSORFROST_COMPILER_HEADER_LIST CONFIGURE_DEPENDS + ${TF_COMPILER_INC_DIR}/*.h + ${TF_COMPILER_INC_DIR}/*.hpp) + +add_library(TensorFrostCompiler STATIC + ${TENSORFROST_COMPILER_SOURCE_LIST} + ${TENSORFROST_COMPILER_HEADER_LIST}) + +target_include_directories(TensorFrostCompiler + PUBLIC + ${TF_COMPILER_INC_DIR}) + +target_compile_features(TensorFrostCompiler PUBLIC cxx_std_20) + +if (MSVC) + target_compile_options(TensorFrostCompiler PRIVATE /wd4804 /wd4805 /wd4018) + target_compile_definitions(TensorFrostCompiler PRIVATE + $<$:_ITERATOR_DEBUG_LEVEL=2> + $<$:_HAS_ITERATOR_DEBUGGING=1> + ) + set_property(TARGET TensorFrostCompiler PROPERTY + MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL") +endif() + +source_group(TREE ${TF_COMPILER_SRC_DIR} PREFIX "Source Files" + FILES ${TENSORFROST_COMPILER_SOURCE_LIST}) + +source_group(TREE ${TF_COMPILER_INC_DIR} PREFIX "Header Files" + FILES ${TENSORFROST_COMPILER_HEADER_LIST}) diff --git a/TensorFrost/Compiler/Graph/Arguments.cpp b/TensorFrost/Compiler/Graph/Arguments.cpp deleted file mode 100644 index d9c91b57..00000000 --- a/TensorFrost/Compiler/Graph/Arguments.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "Arguments.h" - -namespace TensorFrost { - -void ArgumentManager::ClearOutputs() { - outputs_.clear(); -} - -const map arg_type_names = { - {ArgType::Input, "Input"}, {ArgType::Index, "Index"}, {ArgType::Shape, "Shape"}, - {ArgType::Memory, "Memory"}, {ArgType::None, "None"}, -}; - -string TypeToString(ArgType type) { - return arg_type_names.at(type); -} - - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Arguments.h b/TensorFrost/Compiler/Graph/Arguments.h deleted file mode 100644 index b7bf5698..00000000 --- a/TensorFrost/Compiler/Graph/Arguments.h +++ /dev/null @@ -1,221 +0,0 @@ -#pragma once - -#include "Compiler/Operations.h" -#include "Utility/Utility.h" -#include - -namespace TensorFrost { - -enum class ArgType { - Input, - Index, - Shape, - Memory, - None, - Count, -}; - -class Tensor; -class Node; -string TypeToString(ArgType type); - -//argument type and index -using ArgID = pair; -//input nodes with argument type and index -using NodeArguments = map; -//argument type and input node -using Arg = pair; -//argument type and input/output node - edge of the graph -using ArgEdge = pair; - -struct HashArgID { - size_t operator()(const ArgID& id) const { - return (size_t)id.first + id.second * (size_t)ArgType::Count; - } -}; - -//set of edges -using ArgEdges = set; - -#define MAX_ARGS_PER_TYPE 8 -#define MAX_ARGS ((int)ArgType::Count * MAX_ARGS_PER_TYPE) - -class ArgumentManager { - Node* node_; - bool add_parenthesis = false; - unordered_map argument_types_; - unordered_map argument_counts_; - unordered_map argument_names_; - unordered_map argument_requires_parenthesis_; - unordered_map inputs_; - ArgEdges outputs_; - - void AddOutput(ArgID id, Node* node) { - outputs_.insert({{id, node_}, node}); - } - - void RemoveOutput(ArgID id, Node* node) { - if (!outputs_.contains({{id, node_}, node})) { - throw std::runtime_error("Output does not exist"); - } - outputs_.erase({{id, node_}, node}); - } - - void UpdateOutputs(); - void ClearOutputs(); -public: - ArgumentManager(Node* node) { - if (node == nullptr) { - throw std::runtime_error("Node is null"); - } - this->node_ = node; - } - - void AddParenthesis(bool add) { - add_parenthesis = add; - } - - const unordered_map& Inputs() const { - return inputs_; - } - - const ArgEdges& Outputs() const { - return outputs_; - } - - unordered_map InputsCopy() const { - return inputs_; - } - - ArgEdges OutputsCopy() const { - return outputs_; - } - - size_t OutputCount() const { - return outputs_.size(); - } - - void AddArgument(ArgID id, Node *node); - void AddArgument(ArgType type, int index, Node *node) { - AddArgument(ArgID(type, index), node); - } - - void UpdateArgument(ArgID id, Node *node); - - void AddArguments(NodeArguments new_args) { - for (auto& [id, node] : new_args) { - AddArgument(id, node); - } - } - - void SetName(ArgID id, string name, bool requires_parenthesis = false) { - argument_names_[id] = name; - argument_requires_parenthesis_[id] = requires_parenthesis; - } - - bool Has(ArgID id) const { - return inputs_.find(id) != inputs_.end(); - } - - bool Has(ArgType type, int index = 0) const { - ArgID id = ArgID(type, index); - return Has(id); - } - - Node* Get(ArgType type, int index = 0) const { - ArgID id = ArgID(type, index); - auto Arg = inputs_.find(id); - if (Arg != inputs_.end()) { - return Arg->second; - } else { - throw std::runtime_error("Argument of type " + TypeToString(type) + " at index " + std::to_string(index) + " not found"); - } - } - - void Remove(ArgID id); - void Remove(ArgType type, int index = 0) { - Remove(ArgID(type, index)); - } - - const Tensor *GetTensor(ArgType type, int index = 0) const; - - const Tensor& operator[](int index) const; - - TFDataFormat Format(ArgType type, int index = 0) const { - ArgID id = ArgID(type, index); - auto Arg = argument_types_.find(id); - if (Arg != argument_types_.end()) { - return Arg->second; - } - else { - throw std::runtime_error("Argument format not found"); - } - } - - int Count(ArgType type) const { - auto Arg = argument_counts_.find(type); - if (Arg != argument_counts_.end()) { - return Arg->second; - } - else { - return 0; - } - } - - bool RequiresParenthesis(ArgID id) const { - auto Arg = argument_requires_parenthesis_.find(id); - if (Arg != argument_requires_parenthesis_.end()) { - return Arg->second; - } - else { - return false; - } - } - - string Name(ArgType type, int index = 0) const { - ArgID id = ArgID(type, index); - auto Arg = argument_names_.find(id); - if (Arg != argument_names_.end()) { - string name = Arg->second; - if (add_parenthesis && RequiresParenthesis(id)) { - name = "(" + name + ")"; - } - return name; - } - else { - throw std::runtime_error("Argument name not found"); - } - } - - NodeArguments GetArguments() const { - NodeArguments arguments; - for (auto& [id, node] : inputs_) { - arguments[id] = node; - } - return arguments; - } - - NodeArguments GetArguments(ArgType type) const { - NodeArguments arguments; - for (auto& [id, node] : inputs_) { - if (id.first == type) { - arguments[id] = node; - } - } - return arguments; - } - - map GetTensors(ArgType type) const; - - vector GetTensorVector(ArgType type) const; - - ~ArgumentManager(); - - bool CannotMoveArgument(ArgID id); - bool CannotCopyArgument(ArgID id); - bool IsChangingInput(ArgID arg); - - void RemoveArguments(ArgType arg); -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/IR.cpp b/TensorFrost/Compiler/Graph/IR.cpp deleted file mode 100644 index 5eb497c1..00000000 --- a/TensorFrost/Compiler/Graph/IR.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "IR.h" - -namespace TensorFrost { - -int IR::max_kernel_memory_dependencies = 12; -int IR::max_allowed_memory_dependencies = 12; - -void IR::RemoveNode(Node* node) { - if (node->valid()) { - //remove itself from all outputs - auto outputs = node->args.Outputs(); - for (auto& [id, output_node] : outputs) { - output_node->args.Remove(id.first); - } - - //remove all inputs - auto inputs = node->args.InputsCopy(); - for (auto& [id, input_node] : inputs) { - node->args.Remove(id); - } - - // if its the child of its parent, update the parent's child - if (node->parent && node->parent->child == node) { - node->parent->child = node->next; - } - - // if child node exists, iterate through it and remove all children - if (node->child) { - vector to_delete; - for (auto child = NodeIterator(node); !child.end(); child.next()) { - to_delete.push_back(*child); - } - for (int i = (int)to_delete.size() - 1; i >= 0; i--) { - RemoveNode(to_delete[i]); - } - } - - // if direct child of its parent - if (node->parent && node->parent->child == node) { - node->parent->child = node->next; - } else if (node->prev) { - node->prev->next = node->next; - } - - node->next->prev = node->prev; - existing_nodes.erase(node); - delete node; - } -} - -#ifdef _RELWITHDEBINFO -//#define PROFILE_COMPILATION -#endif - -void IR::RunCompilationPass(string pass_name, const function& expression, bool print, bool update_graph) { - current_pass = pass_name; -#ifdef PROFILE_COMPILATION - auto start = std::chrono::high_resolution_clock::now(); -#endif - - try { - expression(); - } catch (const std::exception& e) { - CheckIR(pass_name, false, false); - throw std::runtime_error("Error in compilation pass " + pass_name + ": " + e.what()); - } - -#ifdef PROFILE_COMPILATION - auto end = std::chrono::high_resolution_clock::now(); - float duration = (float) std::chrono::duration_cast(end - start).count() / 1000.0f; - - PassStats stats; - stats.pass_name = pass_name; - stats.duration = duration; - stats.node_count = 0; - for (auto node = begin(); !node.end(); node.next()) { - stats.node_count++; - } - pass_stats.push_back(stats); -#endif - - if (update_graph) { - UpdateGraph(); - } - - if (print) { - CheckIR(pass_name, false, false); - } - current_pass = ""; -} - -#define MAX_COMPILATION_ITERATIONS 32 -#define LOAD_FUSION - -bool IR::RunIterativeCompilationPass(string pass_name, int max_iterations, const function& expression, bool print, bool update_graph) { - bool anything_happened = false; - RunCompilationPass(pass_name, [&]() { - bool converged = false; - for (int i = 0; i < max_iterations; i++) { - bool finished = expression(); - if (finished) { - converged = true; - break; - } else { - anything_happened = true; - } - } - if (!converged) { - throw std::runtime_error("Failed to converge on " + pass_name + ", exceeded maximum iterations"); - } - }, print, update_graph); - return anything_happened; -} - -void IR::CompileIR() -{ - // TODO (Moroz): Add auto tests into build system - CheckIR("Input", false, false); - RunCompilationPass("InitialCompilation", [&]() { - RunCompilationPass("GetInputList", [&]() { GetInputList(); }); - RunCompilationPass("OptimizeOperations", [&]() { OptimizeOperations(); }); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }, true); - RunCompilationPass("UnrollLoops", [&]() { UnrollLoops(); }, true); - RunCompilationPass("TryReplaceModificationsWithVersions", [&]() { TryReplaceModificationsWithVersions(); }, true); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }, true); - - RunIterativeCompilationPass("InsertAlgorithmicPrimitives_PreAutodiff", MAX_COMPILATION_ITERATIONS, [&]() { - return InsertAlgorithmicPrimitives(true); - }, true); - - RunIterativeCompilationPass("ComputeAutodiff", MAX_COMPILATION_ITERATIONS, [&]() { - return ComputeAutodiff(); - }); - - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }, true); - RunCompilationPass("UnrollAtomicOperations", [&]() { UnrollAtomicOperations(); }); - RunCompilationPass("OptimizeReductions", [&]() { OptimizeReductions(); }, true); - - RunIterativeCompilationPass("InsertAlgorithmicPrimitives_PostAutodiff", MAX_COMPILATION_ITERATIONS, [&]() { - return InsertAlgorithmicPrimitives(false); - }, true); - - RunCompilationPass("TryReplaceModificationsWithVersions", [&]() { TryReplaceModificationsWithVersions(); }); - RunCompilationPass("OptimizeOperations", [&]() { OptimizeOperations(); }); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }, true); - }, true); - - RunCompilationPass("KernelGeneration", [&]() { - RunCompilationPass("SeparateOperationsIntoKernels", [&]() { SeparateOperationsIntoKernels(); }, true); - RunCompilationPass("CheckKernelShapes", [&]() { CheckKernelShapes(); }); - - RunCompilationPass("ReorderOperations", [&]() { ReorderOperations(); }); - RunCompilationPass("MoveShapeOutsideKernels", [&]() { MoveShapeOutsideKernels(); }); - RunCompilationPass("OptimizeKernels", [&]() { OptimizeKernels(); }); - RunCompilationPass("OptimizeHost", [&]() { OptimizeHost(); }); - - RunCompilationPass("UnrollLoops", [&]() { UnrollLoops(4); }); - RunCompilationPass("TryReplaceModificationsWithVersions", [&]() { TryReplaceModificationsWithVersions(); }, true); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }); - RunCompilationPass("CheckKernelShapes", [&]() { CheckKernelShapes(); }); - - RunCompilationPass("UpdateKernelShapes", [&]() { UpdateKernelShapes(); }, true); - }, true); - -#ifdef LOAD_FUSION - RunIterativeCompilationPass("Iterative load fusion", MAX_COMPILATION_ITERATIONS, [&]() { - AddKernelGlobalLoadOperations(); - AddMemoryOpIndices(); - bool no_changes = OptimizeKernelLoadOperations(); - if(!no_changes) { - RemoveUnusedOperations(); - } - return no_changes; - }); -#endif - - RunCompilationPass("Reclusterize", [&]() { - LimitKernelMemoryDependencies(); - }, true); - - RunCompilationPass("FinilizeKernels", [&]() { - RunCompilationPass("AddKernelGlobalStoreOperations", [&]() { AddKernelGlobalStoreOperations(); }); - RunCompilationPass("AddKernelGlobalStoreOperations: RemoveUnusedKernels", [&]() { RemoveUnusedKernels(); }, true); - RunCompilationPass("AddMemoryOpIndices", [&]() { AddMemoryOpIndices(); }); - RunCompilationPass("ReorderOperations", [&]() { ReorderOperations(); }); - RunCompilationPass("OptimizeOperations", [&]() { OptimizeOperations(); }); - RunCompilationPass("AddMemoryOpIndices", [&]() { AddMemoryOpIndices(); }, true); - - //TODO: Add support for unrolling small constant kernel dimensions - //RunCompilationPass("UnrollOperations", [&]() { UnrollOperations(); }, true); - //RunCompilationPass("SqueezeKernelShapes", [&]() { SqueezeKernelShapes(); }); - - RunCompilationPass("FinalizeMemoryIndexing", [&]() { FinalizeMemoryIndexing(); }); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }); - RunCompilationPass("OptimizeKernels", [&]() { OptimizeKernels(); }); - RunCompilationPass("OptimizeHost", [&]() { OptimizeHost(); }); - RunCompilationPass("OptimizeOperations", [&]() { OptimizeOperations(); }); - RunCompilationPass("OptimizeHostValuesWithHints", [&]() { OptimizeHostValuesWithHints(); }); - RunCompilationPass("RemoveUnusedOperations", [&]() { RemoveUnusedOperations(); }); - - RunCompilationPass("RemoveUnusedKernels", [&]() { RemoveUnusedKernels(); }, true); - RunCompilationPass("AddMemoryDeallocation", [&]() { AddMemoryDeallocation(); }, true); - //RunCompilationPass("CheckFinalKernelShapes", [&]() { CheckKernelShapes(); }); - }, true); - - RunCompilationPass("GetOutputList", [&]() { GetOutputList(); }); - RunCompilationPass("ComputeStatistics", [&]() { ComputeStatistics(); }); - -#ifdef PROFILE_COMPILATION - cout << "Profiled compilation passes:" << endl; - for (const PassStats& stats : pass_stats) { - cout << "Pass: " << stats.pass_name << " took " << stats.duration << "ms and processed " << stats.node_count << " nodes" << endl; - } -#endif -} - -int GetAxis(int dims, int axis) -{ - if (axis < 0) - { - axis = dims + axis; - } - return axis; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/IR.h b/TensorFrost/Compiler/Graph/IR.h deleted file mode 100644 index b0629ec9..00000000 --- a/TensorFrost/Compiler/Graph/IR.h +++ /dev/null @@ -1,413 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "Compiler/Operations.h" -#include "Utility/Utility.h" -#include "Node.h" -#include "Scope.h" - -namespace TensorFrost { - -class IR { -public: - Node* root; - NodeIterator cursor; - - IR() { - root = new Node(); - root->index_ = 0; - root->name = "root"; - root->initialize(nullptr, {}, "host", {TFType::None, 0}, existing_nodes, true); - cursor = NodeIterator(root); - } - - ~IR() { - vector to_delete; - for (auto node = begin(); !node.end(); node.next()) { - to_delete.push_back(*node); - } - for (Node* node : to_delete) { - delete node; - } - delete root; - } - - NodeIterator begin() const { - return NodeIterator(root); - } - - Node* AddNode(Tensor* tensor, NodeArguments&& args, string&& name, TFDataFormat type) { - Node* newNode = nullptr; - if (cursor->valid()) { //already initialized, add new node before cursor - newNode = new Node(cursor->prev, cursor->parent); - if (cursor->prev) - cursor->prev->next = newNode; - else if (cursor->parent) - cursor->parent->child = newNode; - cursor->prev = newNode; - newNode->next = *cursor; - newNode->initialize(tensor, std::move(args), std::move(name), type, existing_nodes); - } else { - newNode = cursor.get(); - cursor->initialize(tensor, std::move(args), std::move(name), type, existing_nodes); - cursor.go_to_next(); - } - -#ifndef NDEBUG - newNode->created_in_pass = current_pass; - newNode->created_in_function = current_function; -#endif - - return newNode; - } - - void MoveNodeTo(Node* target_place, Node* note_to_move) { - if (note_to_move->valid()) { - //remove from current position - if (note_to_move->parent && note_to_move->parent->child == note_to_move) { - note_to_move->parent->child = note_to_move->next; - } - else if (note_to_move->prev) { - note_to_move->prev->next = note_to_move->next; - } - note_to_move->next->prev = note_to_move->prev; - - //insert into new position - note_to_move->parent = target_place->parent; - note_to_move->prev = target_place->prev; - note_to_move->next = target_place; - if (target_place->prev) { - target_place->prev->next = note_to_move; - } - else if (target_place->parent) { - target_place->parent->child = note_to_move; - } - target_place->prev = note_to_move; - } - } - - void RemoveNode(Node* node); - - void SetCursor(Node* node) { - if(node != nullptr) { - cursor = NodeIterator(node, root); - } else { - throw std::runtime_error("Cursor cannot be set to null"); - } - } - - bool LimitKernelMemoryDependencies(); - - void UnrollOperations(); - - stack scope_stack; - - void EndScope() { - if (scope_stack.empty()) throw std::runtime_error("No scope to end"); - SetCursor(scope_stack.top()); - scope_stack.pop(); - } - - void BeginScope(Node* node) { - scope_stack.push(*cursor); - SetCursor(node); - } - - void BeginScopeLastChild(Node* node) { - BeginScope(node->GetLastChild()); - } - - void ExecuteExpressionAfter(Node* node, const function&& expression) { - BeginScope(node->next); - expression(); - EndScope(); - } - - void ExecuteExpressionBefore(Node* node, const function&& expression) { - BeginScope(node); - expression(); - EndScope(); - } - - void ExecuteExpressionFirstChild(Node* node, const function&& expression) { - BeginScope(node->child); - expression(); - EndScope(); - } - - void ExecuteExpressionLastChild(Node* node, const function&& expression) { - BeginScopeLastChild(node); - expression(); - EndScope(); - } - - void CheckIR(string name, bool check_clustering, bool check_kernels); - string PrintListing(map node_debug) const; - - string GetNodeListing(Node *node) const; - - map CopyNodes(set nodes_to_copy, - unordered_map argument_replacements, - unordered_map indices, - unordered_set targets, bool must_copy_all); - map CopyComputation(const unordered_set& targets, - const unordered_map& indices); - void GetInputList(); - void GetOutputList(); - void ComputeStatistics(); - void CopyArguments(ArgEdges args_to_copy, Node *cursor); - map CopyNodesWithIndex(unordered_set nodes_to_copy, - unordered_map indices, Node* cursor = nullptr); - void ReorderOperations(); - void MoveShapeOutsideKernels(); - bool OptimizeKernels(); - void OptimizeHost(); - void OptimizeOperations(); - void OptimizeHostValuesWithHints(); - - bool OptimizeKernelLoadOperations(); - void OptimizeReductions(); - - unordered_set GetDependencies(unordered_set nodes); - - void RemoveUnusedOperations(); - - bool InsertAlgorithmicPrimitives(bool skip_differentiable); - void UnrollLoops(int max_iterations = 8); - void UnrollAtomicOperations(); - void TryReplaceModificationsWithVersions(); - bool ComputeAutodiff(); - void SeparateOperationsIntoKernels(); - void ComputeNodeCost(); - - map GetKernelOutputs(Node *kernel); - void AddNodeLoadOperations(Node* node, Node* kernel, Tensors indices); - void AddKernelGlobalLoadOperations(); - void AddMemoryOpIndices(); - void AddKernelGlobalStoreOperations(); - - unordered_set ComputeKernelDependencies(Node* kernel); - - void CheckKernelShapes(); - void UpdateKernelShapes(); - void AddMemoryDeallocation(); - void RunCompilationPass(string pass_name, const function &expression, bool print = false, bool update_graph = false); - - bool RunIterativeCompilationPass(string pass_name, int max_iterations, const function &expression, - bool print = false, - bool update_graph = false); - - void ReplaceDimNodes(Node* kernel, Tensors indices, int dims); - void MultiDimensionalModeIndices(vector& indices, Node* kernel_, - int dims, Tensors kernel_shape); - Tensor* LinearBlockModeIndices(Tensors& indices, Node* kernel_, int dims, - Tensors kernel_shape); - - void ComputeAddress(Node *node, Tensors indices); - - void FinalizeMemoryIndexing(); - void RemoveUnusedKernels(); - void CompileIR(); - - void UpdateIndex() { - int index = 0; - for (auto node = begin(); !node.end(); node.next()) { - node->UpdateEdges(); - node->index_ = index++; - } - - index++; //add root node - - if (index != existing_nodes.size()) { - unordered_set found_nodes; - found_nodes.insert(root); - for (auto node = begin(); !node.end(); node.next()) { - found_nodes.insert(*node); - } - - unordered_set missing_nodes; - for (auto node : existing_nodes) { - if (found_nodes.find(node) == found_nodes.end()) { - missing_nodes.insert(node); - } - } - - string missing_nodes_str = ""; - for (auto node : missing_nodes) { - missing_nodes_str += GetNodeListing(node) + "\n"; - } - - missing_nodes_str += "\n" + PrintListing({}); - - throw std::runtime_error("\n Some nodes got lost during indexing. Expected " + to_string(existing_nodes.size()) + " but got " + to_string(index) + ". Likely invalid graph. Missing nodes:\n" + missing_nodes_str); - } - } - - void UpdateGraph(const Node* uroot = nullptr) { - if (uroot == nullptr) { - uroot = root; - } - - UpdateIndex(); - -#ifdef _DEBUG - map invalid_nodes; - // check if graph is valid - for (auto node = NodeIterator(uroot); !node.end(); node.next()) { - // if there are null inputs throw an error - for (auto& [id, n] : (*node)->args.Inputs()) { - if (n == nullptr) { - throw std::runtime_error("Null input found in node " + (*node)->var_name + ". Likely an icorrectly deleted node."); - } else if (n->index_ > (*node)->index_) { //if input node is after current node, throw an error - invalid_nodes[*node] = "Argument " + TypeToString(id.first) + ":" + - to_string(id.second) + " " + n->var_name + " is after current node"; - } - } - } - - if(invalid_nodes.size() > 0) { -#ifdef NDEBUG - std::string error = "Invalid graph: "; - for (auto [node, message] : invalid_nodes) { - error += GetNodeListing(node) + ": " + message + "\n"; - } - throw std::runtime_error(error); -#else - throw std::runtime_error("Invalid graph: " + PrintListing(invalid_nodes)); -#endif - } - -#endif - - //update modified flags - for (auto node = NodeIterator(uroot); !node.end(); node.next()) { - node->flags.remove(NodeProp::Modified); - //go over all outputs and check if they are modifiers - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (to->op->HasAllTypes(OpProp::Modifier)) { - bool is_memory = false; - if (id.first != ArgType::Memory) { - is_memory = true; - } - if (!is_memory) { - node->flags.set(NodeProp::Modified); - break; - } - } - } - } - } - - vector GetNodesOfType(const string& name) const { - vector result; - for (auto node = begin(); !node.end(); node.next()) { - if (node->name == name) { - result.push_back(*node); - } - } - return result; - } - - template - vector GetNodesOfType(OpProp type, Args... args) const { - vector result; - for (auto node = begin(); !node.end(); node.next()) { - if (node->op->HasAllTypes(type, args...)) { - result.push_back(*node); - } - } - return result; - } - - size_t CountNodesOfType(OpProp type) const { - size_t count = 0; - for (auto node = begin(); !node.end(); node.next()) { - if (node->op->HasAllTypes(type)) { - count++; - } - } - return count; - } - - vector GetChildren(Node* node) const { - vector result; - for (auto child = NodeIterator(node); !child.end(); child.next()) { - result.push_back(*child); - } - return result; - } - - size_t CombineHashes(size_t hash1, size_t hash2) const { - return hash1 ^ (hash2 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2)); - } - - size_t GetApproximateStateHash() const { - size_t hash = 0; - for (auto node = begin(); !node.end(); node.next()) { - size_t node_hash1 = std::hash{}(node->name); - size_t node_hash2 = std::hash{}(node->debug_name); - size_t node_hash3 = std::hash{}(node->debug_index); - hash = CombineHashes(hash, CombineHashes(node_hash1, CombineHashes(node_hash2, node_hash3))); - } - return hash; - } - - int input_memory_count = 0; - int output_memory_count = 0; - int temp_memory_count = 0; - - int readbacks = 0; - int writebacks = 0; - - unordered_map> shape_memory_map; - unordered_map input_memory_map; - unordered_map output_memory_map; - - string current_pass = "Tracing initial graph"; - string current_function = "None"; - - struct PassStats { - string pass_name; - float duration; - int node_count; - }; - vector pass_stats; - - void ReplaceArgs(const ArgEdges& edges, const map& replacements) { - edgesToUpdate.insert(edges.begin(), edges.end()); - replacementNodes.insert(replacements.begin(), replacements.end()); - } - - void RemoveNodes(const vector& nodes) { - removedNodes.insert(removedNodes.end(), nodes.begin(), nodes.end()); - } - - void ApplyChanges(bool update_graph = true, const Node *uroot = nullptr); - void ClearChanges(); - - ArgEdges edgesToUpdate{}; - map replacementNodes{}; - vector removedNodes{}; - unordered_set existing_nodes{}; - - static int max_kernel_memory_dependencies; - static int max_allowed_memory_dependencies; -}; - -int GetAxis(int dims, int axis); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Node.cpp b/TensorFrost/Compiler/Graph/Node.cpp deleted file mode 100644 index 262da7ec..00000000 --- a/TensorFrost/Compiler/Graph/Node.cpp +++ /dev/null @@ -1,365 +0,0 @@ -#include "IR.h" - -namespace TensorFrost { - -int Node::global_index = 0; - -void ArgumentManager::UpdateOutputs() { - for (auto& [id, node] : inputs_) { - node->args.AddOutput(id, node_); - } -} - -const Tensor *ArgumentManager::GetTensor(ArgType type, int index) const { - return Get(type, index)->GetTensor(); -} - -const Tensor & ArgumentManager::operator[](int index) const { - return *GetTensor(ArgType::Input, index); -} - -map ArgumentManager::GetTensors(ArgType type) const { - map tensors; - for (auto& [id, node] : inputs_) { - if (id.first == type) { - tensors[id.second] = node->GetTensor(); - } - } - return tensors; -} - -ArgumentManager::~ArgumentManager() { - -} - -bool ArgumentManager::CannotMoveArgument(ArgID id) { - Node* from = inputs_[id]; - Node* to = node_; - return (id.first == ArgType::Memory && - !to->op->HasAllTypes(OpProp::Set)) || - (id.first == ArgType::Shape && !to->op->HasAllTypes(OpProp::Memory)) || - from->op->HasAllTypes(OpProp::Memory) || - (from->name == "const" && to->op->HasAllTypes(OpProp::Memory)); //FIX THIS -} - -bool ArgumentManager::CannotCopyArgument(ArgID id) { - Node* from = inputs_[id]; - Node* to = node_; - bool shape = id.first == ArgType::Shape; - bool to_memory = to->op->HasAllTypes(OpProp::Memory); - bool shape_not_memory = shape && !to_memory; - bool is_output = from->flags.has(NodeProp::OutputMemory); - bool is_input = from->flags.has(NodeProp::InputMemory); - bool no_fusion = from->flags.has(NodeProp::StopFusion); - bool cannot_copy = id.first == ArgType::Memory || shape_not_memory || - from->op->HasAllTypes(OpProp::Static) || from->op->HasAllTypes(OpProp::Memory) || - from->flags.has(NodeProp::Modified) || is_output || is_input || no_fusion; - return cannot_copy; -} - -bool ArgumentManager::IsChangingInput(ArgID arg) { - return arg.first == ArgType::Memory && - node_->op->HasAllTypes(OpProp::Modifier); -} - -void ArgumentManager::UpdateArgument(ArgID id, Node *node) { - if(node == nullptr) { - throw std::runtime_error("ArgumentManager: Node is null"); - } - if(!Has(id)) { - throw std::runtime_error("ArgumentManager: Argument " + TypeToString(id.first) + ":" + std::to_string(id.second) + " with node " + node->name + " does not exist"); - } - // inputs_[id]->args.RemoveOutput(id, node_); - // inputs_[id] = node; - // argument_types_[id] = node->type; - // node->args.AddOutput(id, node_); - Remove(id); - AddArgument(id, node); -} - -Node * Node::GetChild(string name) { - for(NodeIterator it = NodeIterator(this); !it.end(); it.next()) { - if(it->name == name) { - return it.get(); - } - } - return nullptr; -} - -Node * Node::GetNodeWithCommonParent(Node *other) { - for (Node* cur_parent = this; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (cur_parent->parent == other->parent) { - return cur_parent; - } - } - for (Node* cur_parent = other; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (cur_parent->parent == this->parent) { - return cur_parent; - } - } - for (Node* cur_parent1 = this; cur_parent1 != nullptr; - cur_parent1 = cur_parent1->parent) { - for (Node* cur_parent2 = other; cur_parent2 != nullptr; - cur_parent2 = cur_parent2->parent) { - if (cur_parent1->parent == cur_parent2->parent) { - return cur_parent1; - } - } - } - throw std::runtime_error("No common parent found"); -} - -Node* Node::GetLastChild() { - NodeIterator it = NodeIterator(this); - for (; !it.end(); it.go_to_next()) {} - return it.get(); -} - -vector Node::GetChildren() { - vector children; - for(NodeIterator it = NodeIterator(this); !it.end(); it.go_to_next()) { - children.push_back(it.get()); - } - return children; -} - -bool Node::HasCommonParents(Node *other, int max_depth) const { - int depth = 0; - for (Node* cur_parent = parent; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (depth++ > max_depth) { - break; - } - if (!other->HasParent(cur_parent)) { - return false; - } - } - return true; -} - -bool Node::HasParent(string name) { - return GetParent(name) != this; -} - -bool Node::HasChild(string name) { - return GetChild(name) != nullptr; -} - -void Node::ValidateParentShapes() const { - //compare the shape of this node with the shape of all its parents - ShapeInfo shape = ShapeInfo(this); - for (Node* cur_parent = parent; cur_parent != nullptr; cur_parent = cur_parent->parent) { - ShapeInfo parent_shape = ShapeInfo(cur_parent); - ShapeCompareResult result = CompareShape(parent_shape, shape); //must only be broadcastable - if (!result.compatible) { - throw std::runtime_error(MakeNodeErrorMessage("The node " + debug_name + " (" + name + ") has incompatible shapes with its parent " + - cur_parent->debug_name + " (" + cur_parent->name + ")", {this, cur_parent})); - } - } -} - -void Node::SetMemoryType(NodeProp memory_type, int index) { - flags.set(memory_type, (int64_t)index); -} - -void Node::CheckNode() const { - // must have operation - if (op == nullptr) { - throw std::runtime_error("Operation object not found"); - } - - // must have tensor - if (tensor_ == nullptr && !flags.has(NodeProp::IsStatic)) { - throw std::runtime_error("Tensor not found"); - } - - //validate the shape of the node if its not scalar - if (args.Count(ArgType::Shape) > 0) { - ValidateParentShapes(); - } -} - -Node * Node::GetLastVersion(Node *latest_node) { - //find last store/scatter operation - Node* last_modifier = this; - int last_index = -1; - Node* loop_node = latest_node->GetParent("loop"); - bool has_loop = loop_node != latest_node; - for (auto [edge, to] : args.Outputs()) { - auto& [id, from] = edge; - bool is_memory = false; - if (id.first != ArgType::Memory) { - is_memory = true; - } - if (is_memory) { - continue; - } - if (to->op->HasAllTypes(OpProp::Modifier)) { - if (to->index_>last_index) { - // either find the last modifier or the last memory node - // or if there is a loop, find the last modifier inside the loop (i.e. - // the previous iteration's modifier) - // if the loop is scalar, then it doesn't matter - bool before_latest = to->index_ < latest_node->index_; - bool inside_loop = has_loop && to->HasParent(loop_node); - bool not_same = to != latest_node; - if ((before_latest || inside_loop) && not_same) - { - last_index = to->index_; - last_modifier = to; - } - } - } - } - return last_modifier; -} - -Node * Node::GetFinalVersion() { - Node* final_version = this; - int last_index = -1; - for (auto [edge, to] : args.Outputs()) { - auto& [id, from] = edge; - bool is_memory = false; - if (id.first != ArgType::Memory) { - is_memory = true; - } - if (is_memory) { - continue; - } - if (to->op->HasAllTypes(OpProp::Modifier) && !to->op->HasAllTypes(OpProp::MemoryOp)) { - if (to->index_ > last_index) { - last_index = to->index_; - final_version = to; - } - } - } - return final_version; -} - -const map flag_names = { - {NodeProp::Modified, "Modified"}, {NodeProp::Placeholder, "Placeholder"}, - {NodeProp::DetachGrad, "DetachGrad"}, {NodeProp::PassGrad, "PassGrad"}, - {NodeProp::KeepDims, "KeepDims"}, {NodeProp::IsStatic, "IsStatic"}, - {NodeProp::OutputMemory, "OutputMemory"}, {NodeProp::InputMemory, "InputMemory"}, - {NodeProp::InputMemoryList, "InputMemoryList"}, {NodeProp::InputShapeMemory, "InputShapeMemory"}, - {NodeProp::InputShapeDim, "InputShapeDim"}, {NodeProp::NoCopyFusion, "NoCopyFusion"}, - {NodeProp::NoLoadFusion, "NoLoadFusion"}, {NodeProp::StopFusion, "StopFusion"}, {NodeProp::HintMaxValue, "HintMaxValue"}, - {NodeProp::HintMinValue, "HintMinValue"}, {NodeProp::LocalMemoryOp, "LocalMemoryOp"} -}; - -string NodeFlagsToString(NodeProp flags) { - if (!flag_names.contains(flags)) { - throw std::runtime_error("Flag name not defined"); - } - return flag_names.at(flags); -} - -void Node::UpdateEdges() { - if (!child) child = new Node(nullptr, this); - if (!next) next = new Node(this, parent); - if (child->valid()) { - child->parent = this; - } - if (next->valid()) { - next->prev = this; - next->parent = parent; - } -} - -const map indexing_mode_names = { - {IndexingMode::Clamp, "Clamp"}, {IndexingMode::Repeat, "Repeat"},{IndexingMode::Unsafe, "Unsafe"} -}; - -string IndexingModeToString(IndexingMode mode) { - return indexing_mode_names.at(mode); -} - -void Node::initialize(Tensor *tensor, NodeArguments &&new_args, string &&new_name, TFDataFormat new_format, unordered_set& existing_nodes, bool set_static) { - if(valid()) { - throw runtime_error("Node already initialized"); - } - UpdateEdges(); - flags.remove(NodeProp::Placeholder); - - tensor_ = tensor; - format = new_format; - args.AddArguments(std::move(new_args)); - //args.UpdateOutputs(); - flags.set(NodeProp::IsStatic, set_static); - name = std::move(new_name); - op = FindOperation(name); - CheckNode(); - existing_nodes.insert(this); -} - -void Node::CopyProperties(Node *other) { - name = other->name; - debug_name = other->debug_name; - indexing_mode_ = other->indexing_mode_; - group_size = other->group_size; - format = other->format; - - flags.copy_all(other->flags); -} - -void Node::CopyMetadata(Node *other) { - if (other->debug_name != "") { - debug_name = other->debug_name; - } - if(other->indexing_mode_ != IndexingMode::Clamp) { - indexing_mode_ = other->indexing_mode_; - } - group_size = other->group_size; - - flags.copy_all_except(other->flags, {NodeProp::Modified}); -} - -int Node::ComputeDepth(Node *root) const { - int depth = 0; - for (const Node* node = this; node != root; node = node->parent) { - depth++; - } - return depth; -} - -bool Node::HasParent(Node *node) const { - for (Node* cur_parent = parent; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (cur_parent == node) { - return true; - } - } - return false; -} - -void Node::ReplaceThisWithGivenNode(Node *replacement, int min_index, bool make_modified, bool copy_metadata, set nodes_to_modify) { - try { - for (auto [edge, to] : args.OutputsCopy()) { - auto& [id, from] = edge; - if(nodes_to_modify.size() > 0 && !nodes_to_modify.contains(to->name)) { - continue; - } - if (to->index_ >= min_index) { - if(make_modified) { - replacement->flags.set(NodeProp::Modified); - } - to->args.UpdateArgument(id, replacement); - } - } - - if(copy_metadata) { - replacement->CopyMetadata(this); - this->flags.clear(); - } - } catch (const std::exception& e) { - throw std::runtime_error(MakeNodeErrorMessage("Failed to replace node with another node: " + std::string(e.what()), {this, replacement})); - } -} - -Node * Node::GetParent(string name) { - for (Node* cur_parent = parent; cur_parent != nullptr; cur_parent = cur_parent->parent) { - if (cur_parent->name == name) { - return cur_parent; - } - } - return this; -} -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Node.h b/TensorFrost/Compiler/Graph/Node.h deleted file mode 100644 index 9943da24..00000000 --- a/TensorFrost/Compiler/Graph/Node.h +++ /dev/null @@ -1,289 +0,0 @@ -#pragma once - -#include "Compiler/Operations.h" -#include "Utility/Utility.h" -#include "Arguments.h" - -namespace TensorFrost { - -enum class NodeProp { - Placeholder, - Modified, - IsStatic, - OutputMemory, - InputShapeDim, - InputShapeMemory, - InputMemory, - InputMemoryList, - KeepDims, - DetachGrad, - PassGrad, - NoLoadFusion, - NoCopyFusion, - StopFusion, - HintMaxValue, - HintMinValue, - LocalMemoryOp, - Count, -}; - -enum class MemoryType { - None, - Input, - Output, - Constant, -}; - -enum class IndexingMode { - Unsafe, - Clamp, - Repeat -}; - -struct TFTypeDesc { - TFType type; - IndexingMode indexing_mode; - uint constant_value; - -}; - -string NodeFlagsToString(NodeProp flag); -string IndexingModeToString(IndexingMode mode); -using NodeProps = FlagSet; - -class Node { - static int global_index; - public: - int debug_index = -1; - string name; - string var_name = ""; - string debug_name; - int index_ = -1; - float cost_ = -1.0f; - unordered_set memory_deps; - - //Edge *prev, *next; - // - Node *parent = nullptr, *child = nullptr, *next = nullptr, *prev = nullptr; - const Operation* op; - NodeProps flags; - ArgumentManager args; - const Tensor* tensor_; - TFDataFormat format = {TFType::Float, 32}; - std::vector data; - IndexingMode indexing_mode_; //clamp unless otherwise specified - vector group_size; //kernel properties - -#ifndef NDEBUG - string created_in_pass; - string created_in_function; -#endif - - Node(Node* prev = nullptr, Node* parent = nullptr) : parent(parent), prev(prev), args(this) { - flags.set(NodeProp::Placeholder); - debug_index = global_index++; - indexing_mode_ = IndexingMode::Clamp; - } - - bool valid() { - return !flags.has(NodeProp::Placeholder); - } - - void UpdateEdges(); - - //initialize and create next/child placeholders - void initialize(Tensor* tensor, NodeArguments&& new_args, string&& new_name, TFDataFormat new_format, unordered_set& existing_nodes, bool set_static = false); - - void CopyProperties(Node* other); - void CopyMetadata(Node* other); - - const Tensor* GetTensor() const; - int ComputeDepth(Node* root = nullptr) const; - bool HasParent(Node* node) const; - - /// - /// Make all outputs of this node use the given node as input, assuming that the output is further than min_index - /// - /// - /// - void ReplaceThisWithGivenNode(Node* replacement, int min_index = -1, bool make_modified = false, bool copy_metadata = true, set nodes_to_modify = {}); - - Node* GetParent(string name); - Node* GetChild(string name); - - //get the parent that has a common first parent with another node - Node* GetNodeWithCommonParent(Node* other); - - Node* GetLastChild(); - vector GetChildren(); - - //checks if the other node has all parents as this node - bool HasCommonParents(Node* other, int max_depth = 128) const; - - bool HasParent(string name); - bool HasChild(string name); - - void ValidateParentShapes() const; - - void SetMemoryType(NodeProp memory_type, int index = 0); - void CheckNode() const; - Node* GetLastVersion(Node* latest_node); - Node* GetFinalVersion(); - - ~Node(); -}; - -//NodeIterator is a depth first iterator that iterates through the child nodes of a root node -class NodeIterator { - public: - Node* currentNode; - Node* currentParent; - Node* root; - -#ifndef NDEBUG - int iteration_count = 0; - int parent_inconsistency_count = 0; - unordered_set visited; - vector path; -#endif - - NodeIterator() : currentNode(nullptr), root(nullptr), currentParent(nullptr) {} - NodeIterator(Node* node, Node* root) : currentNode(node), root(root), currentParent(node->parent) {} - NodeIterator(const Node* node, const Node* root) - : currentNode(const_cast(node)), root(const_cast(root)), currentParent(const_cast(node->parent)) {} - NodeIterator(Node* node_root) - : currentNode(node_root->child), root(node_root), currentParent(node_root) {} - NodeIterator(const Node* node_root) - : currentNode(const_cast(node_root->child)), - root(const_cast(node_root)), - currentParent(const_cast(node_root)) {} - - Node* operator*() const { return currentNode; } - void update_current_node(Node* new_node) { - currentNode = new_node; -#ifndef NDEBUG - if (visited.contains(currentNode)) { - throw std::runtime_error("Node already visited, potential cycle in operation graph"); - } - visited.insert(currentNode); - path.push_back(currentNode); - iteration_count++; -#endif - } - - NodeIterator& go_to_next() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - update_current_node(currentNode->next); -#ifndef NDEBUG - if (currentNode->parent != currentParent) { - parent_inconsistency_count++; - } -#endif - - return *this; - } - - NodeIterator& go_to_parent() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - update_current_node(currentNode->parent); -#ifndef NDEBUG - currentParent = currentNode->parent; -#endif - - return *this; - } - - NodeIterator& go_to_child() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - -#ifndef NDEBUG - currentParent = currentNode->parent; -#endif - update_current_node(currentNode->child); - - return *this; - } - - NodeIterator& up() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - if (root == currentNode) { - throw std::runtime_error("Already at root"); - } - - if (currentNode->parent != root) { - go_to_parent(); - } else { - go_to_next(); - } - return *this; - } - - NodeIterator& forward() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - if (!currentNode->valid()) { - return *this; - } - - if (!currentNode->next->valid()) { // no next, try going up - Node* parent = currentNode->parent; - while (!parent->next->valid() && root != parent) { - parent = parent->parent; - } - if (root != parent) { // go to next sibling - currentNode = parent; - } - } - - // just go to next node and stop if it's the end - go_to_next(); - return *this; - } - - // first child, then next - NodeIterator& next() { - if (!currentNode) { - throw std::runtime_error("Invalid node"); - } - - if (!currentNode->valid()) { - return *this; - } - - if (currentNode->child->valid()) { // has child, go down - go_to_child(); - return *this; - } - - forward(); - - return *this; - } - - bool end() { return !currentNode->valid(); } - - Node* operator->() { return currentNode; } - - Node* get() { return currentNode; } - - int depth() { return currentNode->ComputeDepth(root); } - - bool operator!=(const Node* node) { return currentNode != node; } -}; - -std::string MakeNodeErrorMessage(std::string message, std::initializer_list nodes); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Scope.cpp b/TensorFrost/Compiler/Graph/Scope.cpp deleted file mode 100644 index 001ffc3e..00000000 --- a/TensorFrost/Compiler/Graph/Scope.cpp +++ /dev/null @@ -1,325 +0,0 @@ -#include "Scope.h" - -namespace TensorFrost { - -inline bool KernelScope::IsBoundary(const Node* input, const Node* output, - ArgType arg_type, bool is_identity) { - const Operation* input_op = input->op; - const Operation* output_op = output->op; - - // if this node loads something from another node, that node must not be in - // this kernel - if (output_op->HasAllTypes(OpProp::Load, OpProp::MemoryOp)) { - return arg_type == ArgType::Memory; - } - - // if we are modifying memory, then the modified memory must not be in the - // kernel - if (output_op->HasAnyType(OpProp::Scatter, OpProp::Store) && - !input_op->HasAnyType(OpProp::Scatter, OpProp::Store)) { - return arg_type == ArgType::Memory; - } - - //if the input is a scatter, then this is a boundary - if (input_op->HasAnyType(OpProp::Scatter)) { - //but multiple scatters can be in the same kernel - return !(output_op->HasAnyType(OpProp::Scatter)); - } - - // shape should not be inside kernels - if (arg_type == ArgType::Shape) { - return true; - } - - return false; -} - -bool KernelScope::IsValid() const { - // check if the scope is valid - if (begin == nullptr || end == nullptr) return false; - - // begin and end must have the same parent - if (begin->parent != end->parent) return false; - - if (begin->index_ < 0 || end->index_ < 0) - throw std::runtime_error("Indices are not computed"); - - // begin must be before or at the same index as end - if (begin->index_ > end->index_) return false; - - // check if the boundary nodes are not in the scope - for (Node* boundary_node : boundary_nodes) { - if (boundary_node->index_ >= begin->index_ && - boundary_node->index_ <= end->index_) { - return false; - } - } - - return true; -} - -KernelScope::KernelScope(Node* node, - unordered_set& output_scopes) : begin(node), end(node) { - scope_shape = ShapeInfo(node); - - // if host only, then this can not be a valid kernel scope - if (node->op->HasAllTypes(OpProp::HostOnly)) { - begin = nullptr; - end = nullptr; - return; - } - - if(node->op->HasAllTypes(OpProp::Modifier, OpProp::MemoryOp) && !node->op->HasAllTypes(OpProp::Scatter)) { - if(scope_shape.dim == 0) { - //must be at 1d scalar to properly generate the kernel with a non-atomic write operation - //technically this is also applicable to scatter operations, but when making an atomic counter, it is sometimes nice for its shape to be implicitly - //determined by the shape of the neighboring tensors - scope_shape.ExpandDimensionsTo(1); - } - } - - // find boundary nodes - bool identity = node->args.Count(ArgType::Index) == 0; - - for (auto& input : node->args.Inputs()) { - // get latest input version - Node* latest = input.second->GetLastVersion(node); - // check if input is the boundary of this kernel - bool is_loop_boundary = latest->index_ > node->index_; - if (IsBoundary(latest, node, input.first.first, identity)) { - if (is_loop_boundary) { - latest = latest->GetParent("loop"); - } - boundary_nodes.insert(latest); - } - } - - pair, bool> all_scopes = ComputeScopes(node); - auto child_scopes = all_scopes.first; - bool host_only = all_scopes.second; - - output_scopes.insert(child_scopes.begin(), child_scopes.end()); - - if(host_only) { - begin = nullptr; - end = nullptr; - return; - } - - int scope_count = (int)child_scopes.size(); - if (scope_count == 0) return; - - //if there is more than one child scope, then this node can not be in the scope - if (scope_count > 1) { - begin = nullptr; - end = nullptr; - return; - } - - KernelScope* child_scope = *child_scopes.begin(); - AddBoundaryNodes(child_scope->boundary_nodes); - - ShapeCompareResult result = - CompareShape(scope_shape, child_scope->scope_shape, true); - - if (result.compatible) { - scope_shape = result.broadcast_shape; - } else { - throw std::runtime_error("Something went wrong"); - } -} - -pair, bool> KernelScope::ComputeScopes(Node *root) { - std::unordered_set scopes; - KernelScope* current_scope = new KernelScope(); - bool host_only = false; - for (auto node = NodeIterator(root); !node.end(); node.go_to_next()) { - std::unordered_set child_scopes; - KernelScope* node_scope = new KernelScope(node.get(), child_scopes); - if (node_scope->IsValid()) { // can be merged - KernelScope* merged = KernelScope::Merge(current_scope, node_scope); - if (merged->IsValid()) { - current_scope = merged; - } else { - bool current_is_valid = current_scope->IsValid(); - if (current_is_valid) { - scopes.insert(current_scope); - } - current_scope = node_scope; - } - } else { // has child kernels - // add all child scopes - scopes.insert(child_scopes.begin(), child_scopes.end()); - // add current scope - bool current_is_valid = current_scope->IsValid(); - if (current_is_valid) { - scopes.insert(current_scope); - } - // create a new empty scope - current_scope = new KernelScope(); - host_only = true; - } - } - if (current_scope->IsValid()) { - scopes.insert(current_scope); - } - return {scopes, host_only}; -} - -KernelScope* KernelScope::Merge(KernelScope* a, KernelScope* b) { - bool a_valid = a->IsValid(); - bool b_valid = b->IsValid(); - - if (!a_valid && !b_valid) - throw std::runtime_error("Invalid kernel scopes for merging"); - - if (!a_valid) return b; - if (!b_valid) return a; - - if (a->end->next != b->begin) - throw std::runtime_error("Trying to merge non-adjacent kernel scopes"); - - ShapeCompareResult result = CompareShape(a->scope_shape, b->scope_shape); - - if (!result.exactly_compatible) return new KernelScope(); - - KernelScope* new_scope = new KernelScope(a->begin, b->end, result.broadcast_shape, a->boundary_nodes); - new_scope->AddBoundaryNodes(b->boundary_nodes); - - if (!new_scope->IsValid()) return new KernelScope(); - - return new_scope; -} - -//if shape nodes are compatible, then return the broadcast shape, if not return nullptr -ShapeDimCompareResult CompareShapeDim(Node* a_node, Node* b_node) { - ShapeDimCompareResult result; - result.compatible = false; - result.broadcast = false; - result.exactly_compatible = true; - result.unroll_compatible = true; - result.a_dim = -1; - result.b_dim = -1; - result.unroll_dim = nullptr; - result.broadcast_dim = nullptr; - - if (a_node->name == "const") result.a_dim = a_node->data[0]; - if (b_node->name == "const") result.b_dim = b_node->data[0]; - - // if one of the nodes is a constant = 1, then it is a broadcast - if ((result.a_dim == 1 || result.b_dim == 1) && !(result.a_dim == 1 && result.b_dim == 1)) { - result.compatible = true; - result.broadcast = true; - result.exactly_compatible = false; - - if (result.a_dim == 1) { - result.unroll_dim = a_node; - result.broadcast_dim = b_node; - } else { - result.unroll_dim = b_node; - result.broadcast_dim = a_node; - } - } - - if(!result.compatible) { - // if a and b are constants, then compare their values - if (result.a_dim != -1 && result.b_dim != -1) { - result.compatible = result.a_dim == result.b_dim; - } - } - - if(!result.compatible) { - // otherwise, if a and b are not the same node then they are not the same - // shape (possibly) - result.compatible = a_node == b_node; - } - - if(result.unroll_dim == nullptr) { - result.unroll_dim = a_node; - } - - if(result.broadcast_dim == nullptr) { - result.broadcast_dim = a_node; - } - - if(result.broadcast) { - if(result.a_dim > MAX_DIM_UNROLL || result.b_dim > MAX_DIM_UNROLL) { - result.unroll_compatible = false; - } - } - - result.exactly_compatible = result.compatible && result.exactly_compatible; - result.unroll_compatible = result.compatible && result.unroll_compatible; - return result; -} - -ShapeCompareResult CompareShape(ShapeInfo& a, ShapeInfo& b, bool throw_error) { - ShapeCompareResult result; - result.compatible = true; - result.exactly_compatible = true; - result.unroll_compatible = true; - result.broadcast = false; - result.a_dim = a.dim; - result.b_dim = b.dim; - result.broadcast_dim = max(a.dim, b.dim); - ShapeInfo& max_dim_shape = a.dim > b.dim ? a : b; - - int min_dim = min(a.dim, b.dim); - - for (int i = 0; i < min_dim; i++) { - Node* a_node = a[i]; - Node* b_node = b[i]; - int broadcast_index = i; - - ShapeDimCompareResult res = CompareShapeDim(a_node, b_node); - - if(!res.compatible) { - result.compatible = false; - if (throw_error) { - if(res.a_dim != -1 && res.b_dim != -1) { - throw std::runtime_error("Shapes are not compatible for nodes: " + a.name + " and " + b.name + " with constant values " + to_string(res.a_dim) + " and " + to_string(res.b_dim) + " at index " + to_string(i)); - } - throw std::runtime_error("Shapes are potentially not compatible for nodes: " + a.name + " and " + b.name + " at index " + to_string(i)); - } - break; - } - - if(res.broadcast) { - result.broadcast = true; - result.broadcast_dims.insert(broadcast_index); - } - - result.unroll_compatible = result.unroll_compatible && res.unroll_compatible; - result.exactly_compatible = result.exactly_compatible && res.exactly_compatible; - result.broadcast_shape.AddShape(broadcast_index, res.broadcast_dim); - result.unroll_shape.AddShape(broadcast_index, res.unroll_dim); - } - - //add the rest of the broadcast shape - if(result.compatible) { - for (int i = min_dim; i < result.broadcast_dim; i++) { - result.broadcast = true; - result.broadcast_shape.AddShape(i, max_dim_shape[i]); - result.unroll_shape.AddShape(i, max_dim_shape[i]); - result.broadcast_dims.insert(i); - } - } - - if((result.broadcast && min_dim > 0) || !result.compatible) { - result.exactly_compatible = false; - } - - if (result.compatible && result.broadcast_shape.dim != result.broadcast_dim) { - throw std::runtime_error("Internal Error: Broadcast shape does not match the broadcast dim"); - } - - return result; -} - -ShapeCompareResult CompareShape(const Node* a, const Node* b, bool throw_error) { - ShapeInfo a_info = ShapeInfo(a); - ShapeInfo b_info = ShapeInfo(b); - return CompareShape(a_info, b_info, throw_error); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Graph/Scope.h b/TensorFrost/Compiler/Graph/Scope.h deleted file mode 100644 index 79ba7c7a..00000000 --- a/TensorFrost/Compiler/Graph/Scope.h +++ /dev/null @@ -1,173 +0,0 @@ -#pragma once - -#include "Compiler/Operations.h" -#include "Utility/Utility.h" -#include "Node.h" - -#define MAX_DIM_UNROLL 4 - -namespace TensorFrost { - -using Tensors = vector; - -class ShapeInfo { - public: - vector> shape; - int dim = 0; - string name; - - ShapeInfo() {} - - ShapeInfo(ShapeInfo* shape_info) { - shape = shape_info->shape; - dim = shape_info->dim; - name = shape_info->name; - } - - ShapeInfo(const Node* node) { - dim = node->args.Count(ArgType::Shape); - for (int i = 0; i < dim; i++) { - AddShape(i, node->args.Get(ArgType::Shape, i)); - } - this->name = node->var_name != "" ? node->var_name : node->name; - } - - Node* operator[](int index) const { - return shape[index].first; - - } - - void AddShape(int index, Node* node) { - if(shape.size() <= index) { - shape.resize(index + 1); - } - shape[index] = {node, false}; - dim = max(dim, index + 1); - } - - bool CheckValidity(bool throw_error = false) const { - for (auto node : shape) { - if(node.first == nullptr) { - if (throw_error) { - throw std::runtime_error("Shape not fully defined"); - } - return false; - } - } - return true; - } - - const Tensor* GetTensor(int index) const { - CheckValidity(true); - return shape[index].first->GetTensor(); - } - - bool IsExpanded(int index) const { - return shape[index].second; - } - - Tensors GetTensors() const { - CheckValidity(true); - Tensors tensors = Tensors(); - for (auto node : shape) { - tensors.push_back(node.first->GetTensor()); - } - return tensors; - } - - NodeArguments GetArguments() const { - CheckValidity(true); - NodeArguments arguments; - for (int i = 0; i < shape.size(); i++) { - arguments[ArgID(ArgType::Shape, i)] = shape[i].first; - } - return arguments; - } - - vector GetShape(int default_value = 256) const; - - static float GetSizeEstimate(ShapeInfo &shape); - - void InsertDim(int index, Node* node, bool expanded = false) { - if (index >= shape.size()+1) { - shape.resize(index + 1); - } - shape.insert(shape.begin() + index, {node, expanded}); - dim++; - } - - void ExpandDimensionsTo(int new_dim); -}; - -struct ShapeCompareResult { - bool compatible; - bool unroll_compatible; - bool exactly_compatible; - ShapeInfo broadcast_shape; - ShapeInfo unroll_shape; - set broadcast_dims; - bool broadcast; - int broadcast_dim; - int a_dim; - int b_dim; - int min_dim; -}; - -struct ShapeDimCompareResult { - bool compatible; - bool unroll_compatible; - bool exactly_compatible; - Node* broadcast_dim; - Node* unroll_dim; - bool broadcast; - int a_dim; - int b_dim; -}; - -ShapeCompareResult CompareShape(const Node* a, const Node* b, - bool throw_error = false); - -ShapeCompareResult CompareShape(ShapeInfo& a, ShapeInfo& b, - bool throw_error = false); - -ShapeDimCompareResult CompareShapeDim(Node* a_node, Node* b_node); - -/// -/// Class to select kernel scopes from the IR graph given the constraints and the root node -/// -class KernelScope { - public: - Node* begin = nullptr; - Node* end = nullptr; - ShapeInfo scope_shape; - unordered_set boundary_nodes; - - static bool IsBoundary(const Node* input, const Node* output, ArgType arg_type, bool is_identity); - - KernelScope() : begin(nullptr), end(nullptr) {} - KernelScope(Node* node, unordered_set& output_scopes); - - KernelScope(Node* begin, Node* end, ShapeInfo shape, unordered_set boundary_nodes) - : begin(begin), end(end), scope_shape(shape), boundary_nodes(boundary_nodes) {} - - void CopyProperties(KernelScope* other) { - begin = other->begin; - end = other->end; - scope_shape = other->scope_shape; - boundary_nodes = other->boundary_nodes; - } - - bool IsValid() const; - - static pair, bool> ComputeScopes(Node *root); - - static KernelScope* Merge(KernelScope* a, KernelScope* b); - - void CreateKernel(); - - void AddBoundaryNodes(unordered_set new_boundary_nodes) { - boundary_nodes.insert(new_boundary_nodes.begin(), new_boundary_nodes.end()); - } -}; - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Implementations.cpp b/TensorFrost/Compiler/Implementations.cpp deleted file mode 100644 index a42acdbe..00000000 --- a/TensorFrost/Compiler/Implementations.cpp +++ /dev/null @@ -1,783 +0,0 @@ -#include "Compiler/Implementations.h" - -namespace TensorFrost { - -const Tensor& ReduceGradientToShape(const Tensor& gradient, const Tensor& target) -{ - ShapeCompareResult shape_result = CompareShape(gradient.node_, target.node_); - if (!shape_result.compatible) { - throw std::runtime_error("Autodiff: gradient shape not compatible with target tensor"); - } - - if(!shape_result.broadcast) { - return gradient; - } - - int dim = shape_result.broadcast_dim; - ShapeInfo gradinfo = gradient.GetShapeInfo(); - ShapeInfo targetinfo = target.GetShapeInfo(); - - gradinfo.ExpandDimensionsTo(dim); - targetinfo.ExpandDimensionsTo(dim); - - vector axes_to_reduce; - vector unsqueeze; - for(int i = 0; i < dim; i++) { - int val_a = gradinfo.GetTensor(i)->TryGetConstant(); - int val_b = targetinfo.GetTensor(i)->TryGetConstant(); - bool b_expanded = targetinfo.IsExpanded(i); - if(b_expanded || (val_a != val_b && val_b == 1)) { - axes_to_reduce.push_back(i); - bool should_unsqueeze = i < target.GetDimension(); - unsqueeze.push_back(should_unsqueeze); - } - } - - Tensor* reduced = const_cast(&gradient); - //go in inverse order to keep the dimensions in the same order - for(int i = (int)axes_to_reduce.size() - 1; i >= 0; i--) { - reduced = &Tensor::Sum(*reduced, axes_to_reduce[i]); - if(unsqueeze[i]) { - reduced = &Tensor::Unsqueeze(*reduced, axes_to_reduce[i]); - } - } - -#ifndef NDEBUG - //check if the reduced shape is the same as the target shape - ShapeCompareResult result = CompareShape(reduced->node_, target.node_); - if(!result.compatible) { - throw std::runtime_error("Gradient shape not compatible with target tensor, function ReduceGradientToShape failed"); - } -#endif - - return *reduced; -} - -Tensor* ConstantOutOfBounds(const Tensor* array, Tensors indices, uint constant) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - Tensor* is_out_of_bounds = &Tensor::Constant(0, TFTypeBool32); - for (int i = 0; i < dims; i++) { - Tensor* is_out = &(*indices[i] < Tensor::Constant(0) || *indices[i] >= *shape[i]); - is_out_of_bounds = &(*is_out_of_bounds || *is_out); - } - is_out_of_bounds->SetDebugName("out_of_bounds"); - Tensor* value = &Tensor::Constant(constant, array->node_->format); - Tensor* loaded = &Tensor::Load(*array, indices); - return &Tensor::select(*is_out_of_bounds, *value, *loaded); -} - - -Tensor* IsOutOfBounds(const Tensor* array, Tensors indices) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - Tensor* is_out_of_bounds = &Tensor::Constant(0, TFTypeBool32); - for (int i = 0; i < dims; i++) { - Tensor* is_out = &(*indices[i] < Tensor::Constant(0) || *indices[i] >= *shape[i]); - is_out_of_bounds = &(*is_out_of_bounds || *is_out); - } - is_out_of_bounds->SetDebugName("out_of_bounds"); - return is_out_of_bounds; -} - - -map gradient_functions = -{ - //elementwise operations - {"copy", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad); }}, - {"add", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad, grad); }}, - {"sub", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad, -grad); }}, - {"mul", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[1], grad * in[0]); }}, - {"div", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / in[1], -grad * in[0] / (in[1] * in[1])); }}, - {"neg", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(-grad); }}, - {"exp", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * out); }}, - {"log", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / in[0]); }}, - {"sin", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * Tensor::cos(in[0])); }}, - {"cos", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(-grad * Tensor::sin(in[0])); }}, - {"tan", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * (Tensor::Constant(1.0f) + out * out)); }}, - {"asin", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / Tensor::sqrt(Tensor::Constant(1.0f) - out * out)); }}, - {"acos", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(-grad / Tensor::sqrt(Tensor::Constant(1.0f) - out * out)); }}, - {"atan", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / (Tensor::Constant(1.0f) + out * out)); }}, - {"abs", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * Tensor::sign(in[0])); }}, - {"sign", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f)); }}, - {"exp2", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * Tensor::Constant(log(2.0f)) * out); }}, - {"log2", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / (in[0] * Tensor::Constant(log(2.0f)))); }}, - {"sqrt", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad / (Tensor::Constant(2.0f) * out)); }}, - {"rsqrt", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(-grad / (Tensor::Constant(2.0f) * in[0] * out)); }}, - {"floor", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f)); }}, - {"ceil", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f)); }}, - {"round", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f)); }}, - {"frac", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad); }}, - {"atan2", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[1] / (in[0] * in[0] + in[1] * in[1]), -grad * in[0] / (in[0] * in[0] + in[1] * in[1])); }}, - {"lerp", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[2], grad * (Tensor::Constant(1.0f) - in[2]), grad * (in[0] - in[1])); }}, - {"max", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::select(in[0] > in[1], grad, Tensor::Constant(0.0f)), Tensor::select(in[0] < in[1], grad, Tensor::Constant(0.0f))); }}, - {"min", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::select(in[0] < in[1], grad, Tensor::Constant(0.0f)), Tensor::select(in[0] > in[1], grad, Tensor::Constant(0.0f))); }}, - {"pow", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[1] * Tensor::pow(in[0], in[1] - Tensor::Constant(1.0f)), grad * Tensor::log(in[0]) * out); }}, - {"tanh", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * (Tensor::Constant(1.0f) - out * out)); }}, - {"clamp", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //clamp = min(max(x, min), max) - Tensor& dc_dx = Tensor::select((in[0] < in[1]) || (in[0] > in[2]), Tensor::Constant(0.0f), grad); - Tensor& dc_dmin = Tensor::select(in[0] < in[1], grad, Tensor::Constant(0.0f)); - Tensor& dc_dmax = Tensor::select(in[0] > in[2], grad, Tensor::Constant(0.0f)); - grads.Add(dc_dx, dc_dmin, dc_dmax); - }}, - {"ternary", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f), Tensor::select(in[0], grad, Tensor::Constant(0.0f)), Tensor::select(in[0], Tensor::Constant(0.0f), grad)); }}, - {"lerp", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad * in[2], grad * (Tensor::Constant(1.0f) - in[2]), grad * (in[0] - in[1])); }}, - {"step", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(Tensor::Constant(0.0f), Tensor::Constant(0.0f)); }}, - {"modf", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(grad, Tensor::Constant(0.0f)); }}, - {"fma", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { grads.Add(in[1] * grad, in[0] * grad, grad); }}, - - //matrix operations - {"matmul", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Matmul(grad, Tensor::Transpose(in[1])), Tensor::Matmul(Tensor::Transpose(in[0]), grad)); - }}, - {"transpose", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Transpose(grad, out.axis(1), out.axis(0))); - }}, - {"dot", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - Tensor& unsq_grad = Tensor::Unsqueeze(grad, out.axis()); - grads.Add(unsq_grad * in[1], unsq_grad * in[0]); - }}, - {"unsqueeze", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Squeeze(grad, out.axis())); - }}, - {"squeeze", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Unsqueeze(grad, out.axis())); - }}, - {"dim_sum", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Unsqueeze(grad, out.axis())); - }}, - {"dim_max", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - auto& out_unsq = Tensor::Unsqueeze(out, out.axis()); - auto& grad_unsq = Tensor::Unsqueeze(grad, out.axis()); - grads.Add(Tensor::select(in[0] == out_unsq, grad_unsq, Tensor::Constant(0.0f))); - }}, - {"dim_min", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - auto& out_unsq = Tensor::Unsqueeze(out, out.axis()); - auto& grad_unsq = Tensor::Unsqueeze(grad, out.axis()); - grads.Add(Tensor::select(in[0] == out_unsq, grad_unsq, Tensor::Constant(0.0f))); - }}, - {"dim_prefix_sum", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //b_i = a_0 + ... + a_i - //db_i/da_j = 1 if i >= j, 0 otherwise - //dL/da_j = sum_i dL/db_i * db_i/da_j - //dL/da_j = sum_i dL/db_i * (i >= j) - //g_i == dL/db_i - //dL/da_j = g_j + g_{j+1} + ... + g_n = g_n + g_{n-1} + ... + g_j - //c_i == g_{n-i} - //dL/da_j = c_0 + c_1 + ... + c_j = prefix_sum(c)_j - grads.Add(Tensor::PrefixSum(Tensor::Reverse(grad, out.axis()), out.axis())); - }}, - {"dim_reverse", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(Tensor::Reverse(grad, out.axis())); - }}, - {"reshape", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - const Tensor* memory_input = in.GetTensor(ArgType::Memory); - grads.Add(ArgType::Memory, 0, Tensor::Reshape(grad, memory_input->GetShape())); - }}, - {"assert", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - const Tensor* memory_input = in.GetTensor(ArgType::Memory); - grads.Add(ArgType::Memory, 0, Tensor::Assert(grad, memory_input->GetShape(), memory_input->GetFormat())); - }}, - //memory operations - {"load", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //derivative of load is scatter gradient to the load memory addresses - int index_count = in.Count(ArgType::Index); - - Tensors tensor_indices = in.GetTensorVector(ArgType::Index); - const Tensor& curGrad = *grads.GetGrad(ArgType::Memory, 0); - const Tensor& is_out_of_bounds = *IsOutOfBounds(in.GetTensor(ArgType::Memory), tensor_indices); - const Tensor& grad_out_of_bounds = Tensor::select(is_out_of_bounds, Tensor::Constant(0.0f), grad); - Tensor::ScatterAdd(curGrad, grad_out_of_bounds, tensor_indices, out.node_->indexing_mode_); - }}, - {"store", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //derivative of store is load gradient at the store memory addresses - const Tensor* memory_input = in.GetTensor(ArgType::Memory); - int index_count = in.Count(ArgType::Index); - - Tensors tensor_indices = in.GetTensorVector(ArgType::Index); - const Tensor& memory_grad = *grads.GetGrad(ArgType::Memory, 0); - grads.Add(ArgType::Input, 0, Tensor::Load(memory_grad, tensor_indices, out.node_->indexing_mode_)); - }}, - {"InterlockedAdd", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //derivative of scatter_add is load gradient at the scatter memory addresses - const Tensor* memory_input = in.GetTensor(ArgType::Memory); - int index_count = in.Count(ArgType::Index); - - Tensors tensor_indices = in.GetTensorVector(ArgType::Index); - const Tensor& memory_grad = *grads.GetGrad(ArgType::Memory, 0); - grads.Add(ArgType::Input, 0, Tensor::Load(memory_grad, tensor_indices, out.node_->indexing_mode_)); - }}, - {"set", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - //derivative of set is the gradient of the setted value to the input - const Tensor& memory_grad = *grads.GetGrad(ArgType::Memory, 0); - grads.Add(ArgType::Input, 0, memory_grad); - }}, - {"detached_grad", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - }}, - {"passthrough_grad", [](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - grads.Add(grad); - }}, -}; - -VJPGradientFunction GetVJPForOperation(string name) { - if (!gradient_functions.contains(name)) { - throw std::runtime_error("Cannot compute VJP for operation " + name); - } - return gradient_functions[name]; -} - -void RegisterVJP(string name, VJPGradientFunction vjp) { - if (gradient_functions.contains(name)) { - throw std::runtime_error("VJP for operation " + name + " already registered"); - } - gradient_functions[name] = vjp; -} - -bool HasDerivativeImplemented(string name) { - return gradient_functions.contains(name); -} - -Tensor* ComputeReduction(const Tensor* array, int axis, - std::function reduction_op, string debug_name = "", - uint initial = 0, - std::function element_op = nullptr) { - // Get shape of the array - Tensors shape = array->GetShape(); - - axis = GetAxis((int)shape.size(), axis); - - // Get the number of dimensions - int dims = (int)shape.size(); - - Tensors sum_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - continue; - } - sum_shape.push_back(shape[i]); - } - - // get indices for all dimensions but the last - Tensors indices = Tensors(); - for (int i = 0; i < dims - 1; i++) { - indices.push_back(&Tensor::Index(sum_shape, i)); - } - - // if no dimensions, then add constant 1 - if (sum_shape.empty()) { - sum_shape.push_back(&Tensor::Constant(1)); - } - - Tensors load_index = Tensors(); - for (int id = 0, d = 0; d < dims; d++) { - if (d == axis) { - load_index.push_back(&Tensor::Constant(sum_shape, 0)); - } else { - load_index.push_back(indices[id++]); - } - } - - // start with the first value - Tensor* reduced = &Tensor::Constant(sum_shape, initial, array->node_->format); - reduced->SetDebugName(debug_name); - - // create a loop over the last dimension starting from the second value - Tensor::Loop(Tensor::Constant(0), *shape[axis], Tensor::Constant(1), - [&](const Tensor& i) { - load_index[axis] = &i; - - // load the value - Tensor* value = &Tensor::Load(*array, load_index, IndexingMode::Unsafe); - - if (element_op != nullptr) { - value = element_op(value); - } - - reduced->Set(*reduction_op(reduced, value)); - }); - - return reduced; -} - -Tensor* ComputeScan(const Tensor* array, int axis, std::function scan_op, string debug_name = "", uint initial = 0) { - // Get shape of the array - Tensors shape = array->GetShape(); - - Tensor* scan_result = &Tensor::Memory(shape, array->node_->format); - - axis = GetAxis((int)shape.size(), axis); - - // Get the number of dimensions - int dims = (int)shape.size(); - - Tensors sum_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - continue; - } - sum_shape.push_back(shape[i]); - } - - // get indices for all dimensions but the last - Tensors indices = Tensors(); - for (int i = 0; i < dims - 1; i++) { - indices.push_back(&Tensor::Index(sum_shape, i)); - } - - // if no dimensions, then add constant 1 - if (sum_shape.empty()) { - sum_shape.push_back(&Tensor::Constant(1)); - } - - Tensors load_index = Tensors(); - for (int id = 0, d = 0; d < dims; d++) { - if (d == axis) { - load_index.push_back(&Tensor::Constant(sum_shape, 0)); - } else { - load_index.push_back(indices[id++]); - } - } - - // start with the first value - Tensor* reduced = &Tensor::Constant(sum_shape, initial, array->node_->format); - reduced->SetDebugName(debug_name); - - // create a loop over the last dimension starting from the second value - Tensor::Loop(Tensor::Constant(0), *shape[axis], Tensor::Constant(1), - [&](const Tensor& i) { - load_index[axis] = &i; - // load the value - Tensor* value = &Tensor::Load(*array, load_index, IndexingMode::Unsafe); - reduced->Set(*scan_op(reduced, value)); - Tensor::Store(*scan_result, *reduced, load_index, IndexingMode::Unsafe); - }); - - return scan_result; -} - -Tensor* ComputeSum(const Tensor* array, int axis) { - return ComputeReduction(array, axis, [](Tensor* a, Tensor* b) { - return &(*a + *b); }, "sum"); -} - -Tensor* ComputeNorm(const Tensor* array, int axis) { - return &Tensor::sqrt(Tensor::Sum(*array * *array, axis)); -} - -Tensor* ComputeMean(const Tensor* array, int axis) { - return &(Tensor::Sum(*array, axis) / Tensor::tofloat(*array->GetShape()[axis])); -} - -uint GetInitialMax(TFType type) { - if (type == TFType::Float) { - float init = -FLT_MAX; - return *(uint*)&init; - } - else if (type == TFType::Int) { - int init = INT_MIN; - return *(uint*)&init; - } - return 0; -} - -uint GetInitialMin(TFType type) { - if (type == TFType::Float) { - float init = FLT_MAX; - return *(uint*)&init; - } - else if (type == TFType::Int) { - int init = INT_MAX; - return *(uint*)&init; - } - return 0; -} - -Tensor* ComputeMax(const Tensor* array, int axis) { - uint initial = 0; - if (array->node_->format == TFTypeFloat32) { - float init = -FLT_MAX; - initial = *(uint*)&init; - } - else if (array->node_->format == TFTypeInt32) { - int init = INT_MIN; - initial = *(uint*)&init; - } - return ComputeReduction( - array, axis, [](Tensor* a, Tensor* b) { return &Tensor::max(*a, *b); }, - "max", initial); -} - -Tensor* ComputeMin(const Tensor* array, int axis) { - uint initial = UINT_MAX; - if (array->node_->format == TFTypeFloat32) { - float init = FLT_MAX; - initial = *(uint*)&init; - } - else if (array->node_->format == TFTypeInt32) { - int init = INT_MAX; - initial = *(uint*)&init; - } - return ComputeReduction( - array, axis, [](Tensor* a, Tensor* b) { return &Tensor::min(*a, *b); }, - "min", initial); -} - -Tensor* ComputeProduct(const Tensor* array, int axis) { - uint initial = 1; - if (array->node_->format == TFTypeInt32) { - float init = 1.0f; - initial = *(uint*)&init; - } - return ComputeReduction( - array, axis, [](Tensor* a, Tensor* b) { return &(*a * *b); }, "prod", initial); -} - -Tensor* ComputeAny(const Tensor* array, int axis) { - return ComputeReduction(array, axis, [](Tensor* a, Tensor* b) { return &(*a || *b); }, "any", 0); -} - -Tensor* ComputeAll(const Tensor* array, int axis) { - return ComputeReduction( - array, axis, [](Tensor* a, Tensor* b) { return &(*a && *b); }, "all", ~0); -} - -Tensor* ComputePrefixSum(const Tensor* array, int axis) { - return ComputeScan(array, axis, [](Tensor* a, Tensor* b) { return &(*a + *b); }, "prefix_sum"); -} - -Tensor* Transpose(const Tensor* array, map permutation) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int old_dim = shapeinfo.dim; - Tensors perm_shape = Tensors(); - int permuted_dim = (int)permutation.size(); - - shapeinfo.ExpandDimensionsTo(permuted_dim); - Tensors shape = shapeinfo.GetTensors(); - - for (int i = 0; i < permuted_dim; i++) { - perm_shape.push_back(shape[permutation[i]]); - } - - //create indices - Tensors indices = Tensors(); - for (int i = 0; i < permuted_dim; i++) { - indices.push_back(&Tensor::Index(perm_shape, i)); - } - //permute indices to load the values - Tensors perm_indices = Tensors(old_dim, nullptr); - for (int i = 0; i < permuted_dim; i++) { - int old = permutation[i]; - if(old < old_dim) { - perm_indices[old] = indices[i]; - } - } - //if any nullptr, then put a constant 0 - for (int i = 0; i < old_dim; i++) { - if(perm_indices[i] == nullptr) { - perm_indices[i] = &Tensor::Constant(0); - } - } - - Tensor& loaded = Tensor::Load(*array, perm_indices, IndexingMode::Unsafe); - loaded.SetShape(perm_shape); //in case perm indices has no shape info (0 dim unsqueeze) - loaded.SetDebugName("transposed"); - return &loaded; -} - -Tensor* ReverseDim(const Tensor* array, int axis) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - Tensors indices = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - indices.push_back(&(*shape[i] - Tensor::Constant(1) - Tensor::Index(shape, i))); - } else { - indices.push_back(&Tensor::Index(shape, i)); - } - } - Tensor& loaded = Tensor::Load(*array, indices, IndexingMode::Unsafe); - loaded.SetDebugName("reversed"); - return &loaded; -} - -Tensor* SplitDim(const Tensor* array, const Tensor* splitted, int axis, int split_size) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors new_shape = splitted->GetShape(); - Tensors indices = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - Tensor* index1 = &Tensor::Index(new_shape, i); - Tensor* index2 = &Tensor::Index(new_shape, i + 1); - //merged index - indices.push_back(&(*index2 * (*new_shape[i]) + *index1)); - } else if(i < axis) { - indices.push_back(&Tensor::Index(new_shape, i)); - } else { - indices.push_back(&Tensor::Index(new_shape, i + 1)); - } - } - Tensor* loaded = ConstantOutOfBounds(array, indices, 0); - loaded->SetDebugName("split"); - return loaded; -} - -Tensor* MergeDim(const Tensor* array, const Tensor* merged, int axis) { - ShapeInfo shapeinfo = array->GetShapeInfo(); - int dims = shapeinfo.dim; - axis = GetAxis(dims, axis); - Tensors shape = shapeinfo.GetTensors(); - Tensors new_shape = merged->GetShape(); - Tensors indices = Tensors(); - for (int i = 0; i < dims-1; i++) { - if (i == axis) { - Tensor* merged_index = &Tensor::Index(new_shape, i); - //get split index - indices.push_back(&(*merged_index % *shape[axis])); - indices.push_back(&(*merged_index / *shape[axis])); - } else { - indices.push_back(&Tensor::Index(new_shape, i)); - } - } - Tensor* loaded = ConstantOutOfBounds(array, indices, 0); - loaded->SetDebugName("merged"); - return loaded; -} - -Tensor* ComputeDot(const Tensor* a, const Tensor* b, int axis) { - Tensors shape_a = a->GetShape(); - Tensors shape_b = b->GetShape(); - axis = GetAxis((int)shape_a.size(), axis); - return ComputeSum(&(*a * *b), axis); -} - -//compute the matrix multiplication of two last dimensions -//takes two tensors [T1, T2, ..., Tn, M, N] and [Tm, .., Tn, N, K] and returns [T1, T2, ..., Tm, M, K] -Tensor* ComputeMatMul(const Tensor* a, const Tensor* b) { - ShapeInfo shape_a = a->GetShapeInfo(); - ShapeInfo shape_b = b->GetShapeInfo(); - - if (shape_a.dim < 2 && shape_b.dim < 2) { - throw std::runtime_error("Matrix multiplication requires at least one 2D tensor"); - } - - if(shape_a.dim < 2) { - shape_a.ExpandDimensionsTo(2); - } - if(shape_b.dim < 2) { - shape_b.ExpandDimensionsTo(2); - } - - Tensors shape_a_tensors = shape_a.GetTensors(); - Tensors shape_b_tensors = shape_b.GetTensors(); - - //get shape of the result - Tensors shape_c = Tensors(); - int dim_a = shape_a.dim; - int dim_b = shape_b.dim; - int max_dim = 0; - Tensors max_shape = Tensors(); - //get the shape with most dimensions - if (dim_a < dim_b) { - max_dim = dim_b; - max_shape = shape_b_tensors; - } else { - max_dim = dim_a; - max_shape = shape_a_tensors; - } - - shape_c.push_back(shape_b_tensors[0]); - shape_c.push_back(shape_a_tensors[1]); - for (int i = 2; i < max_dim; i++) { - shape_c.push_back(max_shape[i]); - } - ShapeDimCompareResult result = CompareShapeDim(shape_a_tensors[0]->node_, shape_b_tensors[1]->node_); - if (!result.compatible) { - throw std::runtime_error("Inner dimensions of the matrices must match"); - } - - const Tensor* sum_shape = result.broadcast_dim->GetTensor(); - - // get indices for c elements - Tensors indices_c = Tensors(); - for (int i = 0; i < max_dim; i++) { - indices_c.push_back(&Tensor::Index(shape_c, i)); - } - - // start with 0 - Tensor* c = &Tensor::Constant(shape_c, 0, a->node_->format); - c->SetDebugName("matmul"); - - // loop over k and compute += A t1t2..tN ik * B t1t2..tN kj - Tensor::Loop(Tensor::Constant(0), *sum_shape, Tensor::Constant(1), - [&](const Tensor& k) { - - // get indices for a elements - Tensors indices_a = Tensors(); - - indices_a.push_back(&k); - indices_a.push_back(indices_c[1]); - for (int i = 2; i < dim_a; i++) { - indices_a.push_back(indices_c[i]); - } - - // get indices for b elements - Tensors indices_b = Tensors(); - - indices_b.push_back(indices_c[0]); - indices_b.push_back(&k); - for (int i = 2; i < dim_b; i++) { - indices_b.push_back(indices_c[i]); - } - - Tensor& a_val = Tensor::Load(*a, indices_a, IndexingMode::Unsafe); a_val.node_->flags.set(NodeProp::NoLoadFusion); // disable load fusion for now - Tensor& b_val = Tensor::Load(*b, indices_b, IndexingMode::Unsafe); a_val.node_->flags.set(NodeProp::NoLoadFusion); - - Tensor* prod = &(a_val * b_val); - - c->Set(*c + *prod); - }); - - return c; -} - -map implementation_functions = -{ - {"dim_sum", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeSum(inputs[0],axes[0])); }}, - {"dim_norm", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeNorm(inputs[0],axes[0])); }}, - {"dim_max", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeMax(inputs[0],axes[0])); }}, - {"dim_min", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeMin(inputs[0],axes[0])); }}, - {"dim_mean", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeMean(inputs[0],axes[0])); }}, - {"dim_product", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeProduct(inputs[0],axes[0])); }}, - {"dim_any", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeAny(inputs[0],axes[0])); }}, - {"dim_all", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeAll(inputs[0],axes[0])); }}, - {"dim_prefix_sum", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputePrefixSum(inputs[0],axes[0])); }}, - {"transpose", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - //get the permutation - int dim = (int)inputs[0]->GetDimension(); - dim = std::max(dim, std::max(axes[0], axes[1]) + 1); - map permutation; - for (int i = 0; i < dim; i++) { - if(i == axes[0]) { - permutation[i] = axes[1]; - } else if(i == axes[1]) { - permutation[i] = axes[0]; - } else { - permutation[i] = i; - } - } - outputs.push_back(Transpose(inputs[0], permutation)); - }}, - {"dot", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeDot(inputs[0], inputs[1], axes[0])); }}, - {"matmul", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ComputeMatMul(inputs[0], inputs[1])); }}, - {"unsqueeze", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - map permutation; - int dim = (int)inputs[0]->GetDimension()+1; - dim = std::max(dim, axes[0] + 1); - for(int i = 0; i < dim; i++) { - if(i == axes[0]) { - permutation[i] = dim-1; - } else if (i < axes[0]) { - permutation[i] = i; - } else { - permutation[i] = i - 1; - } - } - outputs.push_back(Transpose(inputs[0], permutation)); - }}, - {"squeeze", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - map permutation; - int dim = (int)inputs[0]->GetDimension() - 1; - for(int i = 0; i < dim; i++) { - if(i < axes[0]) { - permutation[i] = i; - } else { - permutation[i] = i + 1; - } - } - outputs.push_back(Transpose(inputs[0], permutation)); - }}, - {"dim_reverse", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(ReverseDim(inputs[0], axes[0])); }}, - {"dim_split", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(SplitDim(inputs[0], tensor, axes[0], axes[1])); }}, - {"dim_merge", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { outputs.push_back(MergeDim(inputs[0], tensor, axes[0])); }}, - {"dim_repeat", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - const Tensor* input_tensor = inputs[0]; - Tensors shape = input_tensor->GetShape(); - Tensors new_shape = tensor->GetShape(); - Tensors indices = Tensors(); - for (int i = 0; i < (int)new_shape.size(); i++) { - indices.push_back(&Tensor::Index(new_shape, i)); - } - Tensor* loaded = &Tensor::Load(*input_tensor, indices, IndexingMode::Repeat); - loaded->SetDebugName("repeated"); - outputs.push_back(loaded); - }}, - {"smoothstep", [](Tensors& outputs, map inputs, const Tensor* tensor,vector axes ) { - const Tensor& x = *inputs[2]; - const Tensor& edge0 = *inputs[0]; - const Tensor& edge1 = *inputs[1]; - Tensor& x1 = (x - edge0) / (edge1 - edge0); - Tensor& t = Tensor::clamp(x1, Tensor::Constant(0.0f), Tensor::Constant(1.0f)); - Tensor& result = t * t * (Tensor::Constant(3.0f) - Tensor::Constant(2.0f) * t); - outputs.push_back(&result); - }}, -}; - -ImplementationFunction GetImplementationForOperation(string name) { - if (!implementation_functions.contains(name)) { - throw std::runtime_error("Cannot compute implementation for operation " + name); - } - return implementation_functions[name]; -} - -void RegisterImplementation(string name, ImplementationFunction impl) { - if (implementation_functions.contains(name)) { - throw std::runtime_error("Implementation for operation " + name + " already exists"); - } - implementation_functions[name] = impl; -} - - -map algorithm_vjps = {}; - -AlgorithmVJPGradientFunction GetAlgorithmVJPForOperation(string name) { - if (!algorithm_vjps.contains(name)) { - throw std::runtime_error("Cannot compute VJP for operation " + name); - } - return algorithm_vjps[name]; -} - -void RegisterAlgorithmVJP(string name, AlgorithmVJPGradientFunction vjp) { - if (algorithm_vjps.contains(name)) { - throw std::runtime_error("VJP for operation " + name + " already registered"); - } - algorithm_vjps[name] = vjp; -} - -VJPGradientFunction CreateAlgorithmVJP(const string& name) { - VJPGradientFunction vjp = [name](ArgumentManager& in, const Tensor& out, const Tensor& grad, NodeGrads& grads) { - auto inputs = in.GetTensors(ArgType::Input); - AlgorithmVJPGradientFunction impl = GetAlgorithmVJPForOperation(name); - Tensors grad_tensors = impl(inputs, &grad, &out); - for (int i = 0; i < (int)grad_tensors.size(); i++) { - grads.Add(ArgType::Input, i, *const_cast(grad_tensors[i])); - } - }; - return vjp; -} - -void RegisterAlgorithmicPrimitive(const string& name, vector overloads, ImplementationFunction impl, AlgorithmVJPGradientFunction vjp) { - Operation* newop = new Operation(name, overloads, 0, "", {OpProp::Custom, OpProp::Algorithm}); - RegisterNewOperation(newop); - RegisterImplementation(name, impl); - RegisterAlgorithmVJP(name, vjp); - RegisterVJP(name, CreateAlgorithmVJP(name)); -} - -} // namespace TensorFrost - - diff --git a/TensorFrost/Compiler/Implementations.h b/TensorFrost/Compiler/Implementations.h deleted file mode 100644 index 9f221f7e..00000000 --- a/TensorFrost/Compiler/Implementations.h +++ /dev/null @@ -1,112 +0,0 @@ -#pragma once - -#include "Backend/TensorMemory.h" -#include "Tensor/Tensor.h" - -namespace TensorFrost { - -const Tensor& ReduceGradientToShape(const Tensor& gradient, const Tensor& target); -int GetGradAxis(const Tensor& out, const Tensor& grad); - -class NodeGrads -{ - bool stop_fusion = true; - unordered_map stored_gradients; - unordered_map arguments; - unordered_map argument_inputs; -public: - //get element at index - const Tensor& operator[](ArgID id) { - return *stored_gradients[arguments[id]]; - } - - bool Contains(ArgID id) { - return stored_gradients.contains(arguments[id]); - } - - bool Contains(ArgType type, int index = 0) { - return Contains(ArgID(type, index)); - } - - NodeGrads(Node* node, map input_grads) { - try { - for(auto& [id, input] : node->args.Inputs()) { - if (id.first == ArgType::Index || id.first == ArgType::Shape) { - continue; - } - argument_inputs[id] = input->GetTensor(); - arguments[id] = input; - if(input_grads.contains(input)) { - stored_gradients[input] = &ReduceGradientToShape(*input_grads[input], *input->GetTensor()); - } - } - } catch (const std::exception& e) { - throw std::runtime_error("Error in gradient initialization: " + string(e.what())); - } - } - - void Add(ArgType type, int index, const Tensor& tensor) { - ArgID id = ArgID(type, index); - const Tensor* target = argument_inputs[id]; - try { - Tensor& new_tensor = const_cast(ReduceGradientToShape(tensor, *target)); - if(Contains(type, index)) { - auto& old_tensor = *stored_gradients[arguments[id]]; - if(stop_fusion) old_tensor.StopFusion(); - auto& loaded = old_tensor; //TODO: temporary way to restrict fusion, remove after implementing split - stored_gradients[arguments[id]] = &(loaded + new_tensor); - } else { - stored_gradients[arguments[id]] = &new_tensor; - } - } catch (const std::exception& e) { - throw std::runtime_error("Error in gradient addition: " + string(e.what())); - } - } - - const Tensor* GetGrad(ArgID id) { - if(Contains(id)) { - return stored_gradients[arguments[id]]; - } else { - IR* cur_ir = Tensor::GetEvaluationContext(); - const Tensor* input = argument_inputs[id]; - Tensor* zero_grad = nullptr; - cur_ir->ExecuteExpressionAfter(input->node_, [&]() { - zero_grad = &Tensor::Constant(argument_inputs[id]->GetShape(), 0.0f); - stored_gradients[arguments[id]] = zero_grad; - }); - return zero_grad; - } - } - - const Tensor* GetGrad(ArgType type, int index) { - return GetGrad(ArgID(type, index)); - } - - //add gradients to inputs - template - void Add(const Tensor& arg, Args&... args) { - //by default these are ArgType::Input - vector inputs = vector({ &arg, &args... }); - for (int i = 0; i < inputs.size(); i++) { - Add(ArgType::Input, i, *inputs[i]); - } - } -}; - - -typedef function VJPGradientFunction; -typedef function inputs, const Tensor* gradient, const Tensor* tensor)> AlgorithmVJPGradientFunction; - -VJPGradientFunction GetVJPForOperation(string name); -void RegisterVJP(string name, VJPGradientFunction vjp); -bool HasDerivativeImplemented(string name); -//TODO JVPGradientFunction for forward mode autodiff - -typedef function inputs, const Tensor* tensor, vector axes)> ImplementationFunction; - -ImplementationFunction GetImplementationForOperation(string name); -void RegisterImplementation(string name, ImplementationFunction impl); - -void RegisterAlgorithmicPrimitive(const string& name, vector overloads, ImplementationFunction impl, AlgorithmVJPGradientFunction vjp); - -} diff --git a/TensorFrost/Compiler/KernelGen.cpp b/TensorFrost/Compiler/KernelGen.cpp deleted file mode 100644 index 4c0cd3e1..00000000 --- a/TensorFrost/Compiler/KernelGen.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include "Compiler/KernelGen.h" - -namespace TensorFrost { - -Program* GenerateProgram(IR* ir) -{ - ir->CompileIR(); - - auto* program = new Program(ir); - - vector kernels = ir->GetNodesOfType("kernel"); - - for (auto kernel : kernels) - { - // get the kernel type - map variables; - map read_write; - set group_memory; - NodeArguments shape = kernel->args.GetArguments(ArgType::Shape); - size_t variable_index = 0; - - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if(node->name == "group_memory") { - group_memory.insert(*node); - continue; - } - if (node->op->HasAllTypes(OpProp::MemoryOp)) { - //if the memory is inside of this kernel - skip node - if (node->flags.has(NodeProp::LocalMemoryOp)) { - continue; - } - - // get the memory node - const Tensor* memory = node->args.GetTensor(ArgType::Memory); - - if(node->op->HasAllTypes(OpProp::Modifier)) { - read_write[memory->node_] |= true; - } else { - read_write[memory->node_] |= false; - } - } - - // get all input arguments - for (auto [id, from] : node->args.Inputs()) { - if (id.first == ArgType::Input) - { - bool from_outside_kernel = !from->HasParent(kernel); - if (from_outside_kernel && !variables.contains(from)) { - variables[from] = variable_index++; - } - } - } - } - - map read_write_memory; - map read_only_memory; - size_t read_write_index = 0; - size_t read_only_index = 0; - for(auto [node, rw] : read_write) { - if(rw) { - read_write_memory[node] = read_write_index++; - } else { - read_only_memory[node] = read_only_index++; - } - } - - // add the kernel to the program - program->AddKernel(kernel, variables, read_write_memory, read_only_memory, group_memory, shape); - } - - return program; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/KernelGen.h b/TensorFrost/Compiler/KernelGen.h deleted file mode 100644 index 30b3c009..00000000 --- a/TensorFrost/Compiler/KernelGen.h +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include "Backend/TensorMemory.h" -#include "Tensor/Tensor.h" -#include "Compiler/Implementations.h" - -namespace TensorFrost { - -uint GetInitialMax(TFType type); -uint GetInitialMin(TFType type); - -class Kernel { - public: - Node* root; - map variables; - map read_write_memory; - map read_only_memory; - set group_memory; - NodeArguments shape; - - size_t kernel_id_; - string kernel_name_; - string full_generated_code_; - string generated_header_; - string generated_bindings_; - string generated_main_; - - vector var_names; - vector var_types; - - map GetMemoryBindings() { - map result; - for (auto& mem : read_write_memory) { - result[mem.first] = mem.second; - } - for (auto& mem : read_only_memory) { - result[mem.first] = mem.second + read_write_memory.size(); - } - return result; - } -}; - -class Program { - public: - IR* ir_; - vector kernels_; - function unload_callback; - string generated_code_; - string main_function_; - string program_name = "TensorProgram"; - - float last_execution_time = 0.0f; - float shader_compile_time = 0.0f; - float host_compile_time = 0.0f; - - function execute_callback; - - explicit Program(IR* ir) : ir_(ir) {} - - void AddKernel(Node* kernel_node, map variables, map read_write, map read_only, set group_memory, - NodeArguments shape) - { - kernels_.push_back( - {kernel_node, std::move(variables), std::move(read_write), std::move(read_only), std::move(group_memory), std::move(shape)}); - } -}; - -Program* GenerateProgram(IR* ir); - -bool isConstantAndEqualTo(const Tensor* tensor, float value); -bool isConstant(const Tensor* tensor); -Tensor* ApplyMultiOP(const Tensor* a, const Tensor* b, std::function opF32, std::function opI32, std::function opU32); -Tensor* ApplyUnaryOP(const Tensor* a, std::function opF32, std::function opI32, std::function opU32); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Operations.cpp b/TensorFrost/Compiler/Operations.cpp deleted file mode 100644 index 34091587..00000000 --- a/TensorFrost/Compiler/Operations.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "Operations.h" -#include - -namespace TensorFrost { - -unordered_map type_names = { - {TFType::None, "void"}, {TFType::Bool, "bool"}, {TFType::Float, "float"}, - {TFType::Uint, "uint"}, {TFType::Int, "int"}, -}; - -std::unordered_map DataTypeNames = { - {TFType::Float, "Float"}, {TFType::Uint, "Uint"}, - {TFType::Int, "Int"}, {TFType::Bool, "Bool"}, - {TFType::None, "None"}, -}; - -std::map DataFormatNames = { - {TFTypeFloat32, "TFTypeFloat32"}, {TFTypeInt32, "TFTypeInt32"}, - {TFTypeUint32, "TFTypeUint32"}, {TFTypeBool32, "TFTypeBool32"}, - {TFTypeNone, "TFTypeNone"}, -}; - -const vector operations = { - //Scope operations - Operation("host", {""}, 0, "", {OpProp::Static, OpProp::Special, OpProp::HostOnly, OpProp::Nondiff, OpProp::HasChildren}), - Operation("kernel", {""}, 0, "", {OpProp::Static, OpProp::Special, OpProp::HostOnly, OpProp::Nondiff, OpProp::HasChildren}), - - //Control operations - Operation("loop", {"iii_i"}, 100, "", {OpProp::Static, OpProp::Special, OpProp::Nondiff, OpProp::HasChildren}), - Operation("if", {"b_"}, 100, "", {OpProp::Static, OpProp::Special, OpProp::Nondiff, OpProp::HasChildren}), - Operation("break", {""}, 0, "break", {OpProp::Static, OpProp::Nondiff}, OpClass::Keyword), - Operation("continue", {""}, 0, "continue", {OpProp::Static, OpProp::Nondiff}, OpClass::Keyword), - Operation("discard", {""}, 0, "discard", {OpProp::Static, OpProp::Nondiff}, OpClass::Keyword), //discard current thread - Operation("group_barrier", {""}, 256, "", {OpProp::Static, OpProp::KernelOnly}), - - //Allocation operations - Operation("memory", {"_f", "_i", "_u", "_b"}, 0, "", {OpProp::Memory, OpProp::Special, OpProp::HostOnly, OpProp::Nondiff}), - Operation("reshape", {"_f", "_i", "_u", "_b"}, 0, "", {OpProp::Memory, OpProp::Special, OpProp::HostOnly, OpProp::MemoryReuse}), - Operation("assert", {"_f", "_i", "_u", "_b"}, 0, "assert_tensor", {OpProp::Memory, OpProp::Special, OpProp::HostOnly, OpProp::MemoryReuse}), - Operation("input_shape", {"_i"}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff}), - Operation("deallocate", {""}, 0, "", {OpProp::Memory, OpProp::Special, OpProp::HostOnly, OpProp::Nondiff}), - Operation("group_memory", {"_f", "_i", "_u", "_b"}, 0, "", {OpProp::Memory, OpProp::LocalMemory, OpProp::Special, OpProp::KernelOnly}), - Operation("local_memory", {"_f", "_i", "_u", "_b"}, 0, "", {OpProp::Memory, OpProp::LocalMemory, OpProp::Special, OpProp::KernelOnly}), - - Operation("region_begin", {""}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff, OpProp::Debug}), - Operation("region_end", {""}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff, OpProp::Debug}), - Operation("print_value", {"f_", "i_", "u_", "b_"}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff, OpProp::Debug}), - Operation("assert_value", {"b_"}, 0, "", {OpProp::Special, OpProp::Static, OpProp::HostOnly, OpProp::Nondiff, OpProp::Debug}), - - //Algorithms - //Reduction - Operation("dim_sum", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_norm", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm}), - Operation("dim_max", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_min", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_mean", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_prod", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Reduction}), - Operation("dim_any", {"u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm, OpProp::Nondiff, OpProp::Reduction}), - Operation("dim_all", {"u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm, OpProp::Nondiff, OpProp::Reduction}), - //Scan - Operation("dim_prefix_sum", {"f_f", "u_u", "i_i"}, 0, "", {OpProp::Algorithm, OpProp::Scan}), - //Matrix - Operation("transpose", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - Operation("dot", {"ff_f"}, 0, "", {OpProp::Algorithm}), // dot product of the last dimensions - Operation("matmul", {"ff_f"}, 0, "", {OpProp::Algorithm}), // matrix multiplication of the last dimensions - - //Other - Operation("dim_reverse", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - Operation("dim_repeat", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - // Operation("dim_concat", {"ff_f", "uu_u", "ii_i", "bb_b"}, 0, "", {OpProperty::Algorithm}), - Operation("dim_split", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - Operation("dim_merge", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - // Operation("dim_pad", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProperty::Algorithm}), - Operation("unsqueeze", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - Operation("squeeze", {"f_f", "u_u", "i_i", "b_b"}, 0, "", {OpProp::Algorithm}), - - Operation("smoothstep", {"fff_f"}, 10, "", {OpProp::Algorithm}), - - //Autodiff - Operation("backwards_grad", {"ff_f"}, 0, "", {OpProp::Gradient}), - Operation("forward_grad", {"ff_f"}, 0, "", {OpProp::Gradient}), - - // Memory operations - Operation("load", {"_f", "_u", "_i", "_b"}, 128, "", - {OpProp::Load, OpProp::MemoryOp}), - Operation("store", {"f_", "u_", "i_", "b_"}, 128, "", - {OpProp::Store, OpProp::MemoryOp, OpProp::Modifier}), - Operation("set", {"f_", "u_", "i_", "b_"}, 1, "", - {OpProp::Set, OpProp::Modifier}), - Operation("InterlockedAdd", {"u_", "i_", "f_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier}), - Operation("InterlockedMin", {"u_", "i_", "f_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier}), - Operation("InterlockedMax", {"u_", "i_", "f_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier}), - Operation("InterlockedAnd", {"u_", "i_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier, OpProp::Nondiff}), - Operation("InterlockedOr", {"u_", "i_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier, OpProp::Nondiff}), - Operation("InterlockedXor", {"u_", "i_"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier, OpProp::Nondiff}), - Operation("InterlockedAdd_Prev", {"u_u", "i_i", "f_f"}, 256, "", - {OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier, OpProp::CantSubstitute, OpProp::Nondiff}), - - // Index operations - Operation("dim_id", {"_i"}, 0, "dim", {OpProp::Nondiff}, OpClass::DimensionIndex), - Operation("block_thread_id", {"_i"}, 0, "", {OpProp::Nondiff}, OpClass::DimensionIndex), - Operation("block_id", {"_i"}, 0, "", {OpProp::Nondiff}, OpClass::Variable), - - //Compute operations - Operation("copy", {"f_f", "u_u", "i_i", "b_b"}, 1, "", {}, OpClass::Copy), //TODO: make sure no one copies memory objects - Operation("add", {"ff_f", "uu_u", "ii_i"}, 1, "+", {}, OpClass::Operator), - Operation("sub", {"ff_f", "uu_u", "ii_i"}, 1, "-", {}, OpClass::Operator), - Operation("mul", {"ff_f", "uu_u", "ii_i"}, 1, "*", {}, OpClass::Operator), - Operation("div", {"ff_f", "uu_u", "ii_i"}, 2, "/", {}, OpClass::Operator), - Operation("mod", {"ff_f", "uu_u", "ii_i"}, 4, "%", {OpProp::Nondiff}, OpClass::Operator), - Operation("lshift", {"uu_u", "ui_u", "ii_i"}, 1, "<<", {OpProp::Nondiff}, OpClass::Operator), - Operation("rshift", {"uu_u", "ui_u", "ii_i"}, 1, ">>", {OpProp::Nondiff}, OpClass::Operator), - Operation("and", {"uu_u", "ii_i", "bb_b"}, 1, "&", {OpProp::Nondiff}, OpClass::Operator), - Operation("or", {"uu_u", "ii_i", "bb_b"}, 1, "|", {OpProp::Nondiff}, OpClass::Operator), - Operation("xor", {"uu_u", "ii_i", "bb_b"}, 1, "^", {OpProp::Nondiff}, OpClass::Operator), - Operation("eq", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, "==", {OpProp::Nondiff}, OpClass::Operator), - Operation("neq", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, "!=", {OpProp::Nondiff}, OpClass::Operator), - Operation("lt", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, "<", {OpProp::Nondiff}, OpClass::Operator), - Operation("lte", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, "<=", { OpProp::Nondiff}, OpClass::Operator), - Operation("gt", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, ">", {OpProp::Nondiff}, OpClass::Operator), - Operation("gte", {"ff_b", "uu_b", "ii_b", "bb_b", "ui_b", "iu_b"}, 1, ">=", {OpProp::Nondiff}, OpClass::Operator), - Operation("notb", {"b_b"}, 1, "!", {OpProp::Nondiff}, OpClass::UnaryOperator), - Operation("not", {"u_u", "i_i"}, 1, "~", {OpProp::Nondiff}, OpClass::UnaryOperator), - Operation("neg", {"f_f", "u_u", "i_i"}, 1, "-", {}, OpClass::UnaryOperator), - Operation("uint", {"f_u", "u_u", "i_u", "b_u"}, 1, "uint", {OpProp::Nondiff}, OpClass::TypeCast), - Operation("int", {"f_i", "u_i", "i_i", "b_i"}, 1, "int", {OpProp::Nondiff}, OpClass::TypeCast), - Operation("float", {"f_f", "u_f", "i_f", "b_f"}, 1, "float", {OpProp::Nondiff}, OpClass::TypeCast), - Operation("bool", {"f_b", "u_b", "i_b", "b_b"}, 1, "bool", {OpProp::Nondiff}, OpClass::TypeCast), - Operation("asuint", {"f_u", "u_u", "i_u", "b_u"}, 0, "asuint", - {OpProp::Nondiff}, OpClass::TypeReinterpret), - Operation("asint", {"f_i", "u_i", "i_i", "b_i"}, 0, "asint", - {OpProp::Nondiff}, OpClass::TypeReinterpret), - Operation("asfloat", {"f_f", "u_f", "i_f", "b_f"}, 0, "asfloat", - {OpProp::Nondiff}, OpClass::TypeReinterpret), - Operation("asbool", {"f_b", "u_b", "i_b", "b_b"}, 0, "asbool", - {OpProp::Nondiff}, OpClass::TypeReinterpret), - Operation("min", {"ff_f", "uu_u", "ii_i"}, 1), - Operation("max", {"ff_f", "uu_u", "ii_i"}, 1), - Operation("abs", {"f_f", "u_u", "i_i"}, 1), - Operation("sign", {"f_f", "i_i"}, 1), - Operation("ceil", {"f_f"}, 1), - Operation("floor", {"f_f"}, 1), - Operation("round", {"f_f"}, 1), - Operation("frac", {"f_f"}, 1), - Operation("exp", {"f_f"}, 32), - Operation("exp2", {"f_f"}, 16), - Operation("log", {"f_f"}, 32), - Operation("log2", {"f_f"}, 16), - Operation("sqrt", {"f_f"}, 4), - Operation("rsqrt", {"f_f"}, 2), - Operation("rcp", {"f_f"}, 2), - Operation("sin", {"f_f"}, 2), - Operation("cos", {"f_f"}, 2), - Operation("tan", {"f_f"}, 2), - Operation("asin", {"f_f"}, 8), - Operation("acos", {"f_f"}, 8), - Operation("atan", {"f_f"}, 8), - Operation("sinh", {"f_f"}, 8), - Operation("cosh", {"f_f"}, 8), - Operation("tanh", {"f_f"}, 8), - Operation("pcg", {"u_u"}, 32), - Operation("reversebits", {"i_i", "u_u"}, 8), - Operation("pcgf", {"u_f"}, 34, "", {OpProp::Nondiff}), - Operation("pow", {"ff_f"}, 6), - Operation("atan2", {"ff_f"}, 32), - Operation("modf", {"ff_f"}, 2), - Operation("step", {"ff_f"}, 2), - Operation("clamp", {"fff_f", "uuu_u", "iii_i"}, 4), - Operation("lerp", {"fff_f"}, 4), - Operation("fma", {"fff_f"}, 1), - Operation("ternary", {"bff_f", "buu_u", "bii_i", "bbb_b"}, 4, "", {}, OpClass::TernaryOperator), - Operation("const", {"_f", "_u", "_i", "_b"}, 0, "", {OpProp::Nondiff}, OpClass::Constant), -}; - -void RegisterNewOperation(unordered_map& operation_map, const Operation* op) { - if (operation_map.contains(op->name_)) { - throw runtime_error("Operation already exists: " + op->name_); - } - operation_map[op->name_] = op; -} - -unordered_map CreateOperationMap() { - unordered_map operation_map; - for (const auto& op : operations) { - RegisterNewOperation(operation_map, &op); - } - return operation_map; -} - -unordered_map operation_map = CreateOperationMap(); - -void RegisterNewOperation(const Operation* op) { - RegisterNewOperation(operation_map, op); -} - -DataTypeList Types(initializer_list elements) { - return DataTypeList(elements); -} - -const Operation* FindOperation(const string& name) { - if (name == "") { - throw runtime_error("Operation name is empty"); - } - - auto it = operation_map.find(name); - if (it != operation_map.end()) { - return it->second; - } - - throw runtime_error("IR Operation not defined: " + name); -} - -string DataTypeToString(TFType type) { return type_names[type]; } - -string RemoveSpaces(string str) { - str.erase(remove(str.begin(), str.end(), ' '), str.end()); - return str; -} - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Operations.h b/TensorFrost/Compiler/Operations.h deleted file mode 100644 index a6fd6f1f..00000000 --- a/TensorFrost/Compiler/Operations.h +++ /dev/null @@ -1,284 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "Utility/Utility.h" - -namespace TensorFrost { - -using namespace std; - -extern "C" { - enum TFType { - Float, - Uint, - Int, - Bool, - None, - }; - - struct TFDataFormat { - TFType type; - size_t size; - - bool operator==(const TFDataFormat& other) const { - return type == other.type && size == other.size; - } - - bool operator!=(const TFDataFormat& other) const { - return !(*this == other); - } - - int GetHash() const { - return (int)type << 16 | (int)size; - } - - bool operator<(const TFDataFormat& other) const { - return GetHash() < other.GetHash(); - } - - bool operator>(const TFDataFormat& other) const { - return GetHash() > other.GetHash(); - } - }; - -#define TFTypeNone TFDataFormat{TFType::None, 0} -#define TFTypeBool32 TFDataFormat{TFType::Bool, 32} -#define TFTypeFloat32 TFDataFormat{TFType::Float, 32} -#define TFTypeInt32 TFDataFormat{TFType::Int, 32} -#define TFTypeUint32 TFDataFormat{TFType::Uint, 32} -} - -extern std::unordered_map DataTypeNames; -extern std::map DataFormatNames; -extern std::unordered_map type_names; - -//op can have only one class -enum class OpClass { - Operator, - UnaryOperator, - Function, - Copy, - Keyword, - DimensionIndex, - Variable, - TypeCast, - TypeReinterpret, - Constant, - TernaryOperator, - None, -}; - -//op can have multiple properties -enum class OpProp { - Load, - Store, - Set, - Scatter, - Special, - Memory, - LocalMemory, - CantSubstitute, - MemoryOp, - LocalMemoryOp, - Static, //can not be removed or copied - HostOnly, - KernelOnly, - Composite, - Algorithm, - Custom, - Reduction, - Scan, - Modifier, - MemoryReuse, - Gradient, - Nondiff, - HasChildren, - Debug, - Count, -}; - -using OpProps = FlagSet; - -using DataTypeList = vector; - -DataTypeList Types(initializer_list elements); - -class Operation { -public: - string name_; - float cost_ = 0.0F; - vector, TFType>> overloads_; - string code_; - //vector op_classes; - OpProps props_; - OpClass class_; - size_t default_size = 32; - - Operation() = default; - - Operation(string name, vector overloads, float cost, - string code = "", initializer_list op_props = {}, OpClass op_class = OpClass::Function) - : name_(std::move(name)){ - if (code.empty()) { - code = name_; - } - - code_ = code; - cost_ = cost; - - //add op types - for (const auto& type : op_props) { - props_.set(type); - } - - class_ = op_class; - - // parse the overloads - // example: "ff_f" means two floats in, one float out, "buf_f" means a bool, - // uint, float in, float out - for (const auto& oload : overloads) { - vector inputs; - TFType output = TFType::None; - bool is_output = false; - - for (const auto& c : oload) { - TFType parsed_type = TFType::None; - switch (c) { - case 'f': - parsed_type = TFType::Float; - break; - case 'u': - parsed_type = TFType::Uint; - break; - case 'i': - parsed_type = TFType::Int; - break; - case 'b': - parsed_type = TFType::Bool; - break; - case '_': - is_output = true; - break; - default: - throw std::runtime_error("Invalid character in overload string"); - break; - } - - if (is_output) { - output = parsed_type; - } else { - inputs.push_back(parsed_type); - } - } - - overloads_.emplace_back(inputs, output); - } - } - - bool HasAllTypes(OpProp type) const { - return props_.has(type); - } - - template - bool HasAllTypes(OpProp type, Args... args) const { - return HasAllTypes(type) && HasAllTypes(args...); - } - - bool HasAnyType(OpProp type) const { - return HasAllTypes(type); - } - - template - bool HasAnyType(OpProp type, Args... args) const { - return HasAllTypes(type) || HasAnyType(args...); - } - - float GetCost() const { return cost_; } - - string GetName() const { return name_; } - - vector, TFType>> GetOverloads() const { - return overloads_; - } - - size_t GetInputCount() const { - return overloads_[0].first.size(); - } - - bool IsOverloadValid(const pair, TFType>& overload, const vector& input_types) const { - if (overload.first.size() != input_types.size()) { - return false; - } - - for (size_t i = 0; i < input_types.size(); i++) { - if (overload.first[i] != input_types[i].type) { - return false; - } - } - - if (input_types.size() == 0) { - return true; - } - - //check if all format sizes are the same - size_t size = input_types[0].size; - for (size_t i = 1; i < input_types.size(); i++) { - if (input_types[i].size != size) { - return false; - } - } - - return true; - } - - bool IsInputValid(const vector& input_types) const { - for (const auto& overload : overloads_) { - if (IsOverloadValid(overload, input_types)) { - return true; - } - } - return false; - } - - bool IsOutputValid(const TFDataFormat& output_type) const { - for (const auto& overload : overloads_) { - if (overload.second == output_type.type) { - return true; - } - } - return false; - } - - TFDataFormat GetOutputType( - const vector& input_types) const { - for (const auto& overload : overloads_) { - if (IsOverloadValid(overload, input_types)) { - size_t cur_size = default_size; - if (input_types.size() > 0) { - cur_size = input_types[0].size; - } - return {overload.second, cur_size}; - } - } - throw std::runtime_error("Invalid input types for operation"); - } -}; - -const Operation* FindOperation(const string& name); - -string DataTypeToString(TFType type); - -string RemoveSpaces(string str); - -void RegisterNewOperation(const Operation* op); - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Steps/Algorithms.cpp b/TensorFrost/Compiler/Steps/Algorithms.cpp deleted file mode 100644 index d8bf8034..00000000 --- a/TensorFrost/Compiler/Steps/Algorithms.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "Compiler/KernelGen.h" - -namespace TensorFrost { -bool IR::InsertAlgorithmicPrimitives(bool skip_differentiable) { - // get all nodes for each type - vector nodes = GetNodesOfType(OpProp::Algorithm); - - unordered_set nodes_to_remove; - - // replace all nodes with the algorithmic primitive - for (auto node : nodes) { - if(HasDerivativeImplemented(node->name) && skip_differentiable) { - continue; - } - //compute the sum after the node - ExecuteExpressionAfter(node, [&]() { - //get the input tensor - map inputs = node->args.GetTensors(ArgType::Input); - - //get sum axis - vector axes; - for (int i = 0; i < node->data.size(); i++) { - axes.push_back((int)node->data[i]); - } - - Tensors results; - ImplementationFunction func = GetImplementationForOperation(node->name); - -#ifndef NDEBUG - current_function = node->name; -#endif - - func(results, inputs, node->GetTensor(), axes); - -#ifndef NDEBUG - current_function = "None"; -#endif - - const Tensor* result = results[0]; - - //replace the node with the sum - node->ReplaceThisWithGivenNode(result->node_); - - ShapeCompareResult shape_result = CompareShape(node, result->node_); - if (!shape_result.exactly_compatible) { - throw std::runtime_error("Algorithmic primitive " + node->name + " at " + node->debug_name + " has incompatible shapes"); - } - }); - - //mark the node for removal - nodes_to_remove.insert(node); - } - - // remove all nodes that are not used - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); - - return nodes_to_remove.empty(); -} - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Steps/Autodiff.cpp b/TensorFrost/Compiler/Steps/Autodiff.cpp deleted file mode 100644 index 773839ed..00000000 --- a/TensorFrost/Compiler/Steps/Autodiff.cpp +++ /dev/null @@ -1,154 +0,0 @@ -#include "Compiler/KernelGen.h" - -namespace TensorFrost { - -void ComputeNodeGradients(Node* value, const Tensor* grad, NodeGrads& grads) -{ - try { - string op_name = value->name; - //add input arguments - if(value->flags.has(NodeProp::PassGrad)) { - op_name = "passthrough_grad"; - } - if(value->flags.has(NodeProp::DetachGrad)) { - op_name = "detached_grad"; - } - - VJPGradientFunction gradient_func = GetVJPForOperation(op_name); - - Tensor out = *value->tensor_; - gradient_func(value->args, out, *grad, grads); - } catch (const std::exception& e) { - throw std::runtime_error("Error in gradient computation for " + value->debug_name + "(" + to_string(value->debug_index) + "): " + e.what()); - } -} - -bool IR::ComputeAutodiff() -{ - vector gradients = GetNodesOfType(OpProp::Gradient); - - if(gradients.empty()) { - return true; - } - - set loss_nodes; - map, Node*> loss_wrt_grad; - unordered_map min_range; //index of earliest node required for the gradient, end of backpropagation - - for (auto gradient : gradients) { - Node* loss = gradient->args.Get(ArgType::Input, 0); - Node* wrt = gradient->args.Get(ArgType::Input, 1); - Node* last_loss_version = loss->GetLastVersion(gradient); - - loss_nodes.insert(last_loss_version); - if(!min_range.contains(last_loss_version)) { - min_range[last_loss_version] = wrt->index_; - } else { - min_range[last_loss_version] = std::min(min_range[last_loss_version], wrt->index_); - } - loss_wrt_grad[{last_loss_version, wrt}] = gradient; - } - - map grad_to_computed_grad; - for (auto loss : loss_nodes) { - set visited; - map node_to_grad; - - unordered_set loss_deps = GetDependencies({loss}); - - //get all differentiable nodes that can change the loss - vector queue; - for (auto dep : loss_deps) { - bool in_range = (dep->index_ <= loss->index_ && dep->index_ >= min_range[loss]); - bool dep_is_accessible = dep->HasCommonParents(loss); //is it in scope of the loss - if(in_range && !dep->op->HasAllTypes(OpProp::Nondiff) && - dep_is_accessible && (dep->format.type == TFType::Float || dep->op->HasAllTypes(OpProp::Modifier))) { - queue.push_back(dep); - } - } - - //sort the nodes by index in descending order (backpropagation) - ranges::sort(queue.begin(), queue.end(), [](Node* a, Node* b) { - return a->index_ > b->index_; - }); - - Node* loss_value = loss; - if(loss->op->HasAllTypes(OpProp::Modifier)) { - loss_value = loss->args.Get(ArgType::Memory); - } - - ExecuteExpressionAfter(loss, [&]() { - node_to_grad[loss_value] = &Tensor::Constant(1.0f); - for(auto node : queue) { - if(!node_to_grad.contains(node) && !node->op->HasAllTypes(OpProp::Modifier)) { - continue; - } - - node_to_grad[node] = &ReduceGradientToShape(*node_to_grad[node], *node->GetTensor()); - - NodeGrads grads = NodeGrads(node, node_to_grad); - - #ifndef NDEBUG - current_function = node->name + "_grad"; - #endif - - ComputeNodeGradients(node, node_to_grad[node], grads); - - #ifndef NDEBUG - current_function = "None"; - #endif - - //store the computed gradients - for (auto& [id, input]: node->args.Inputs()) { - if(!grads.Contains(id)) { - continue; - } - - const Tensor& new_grad = *grads.GetGrad(id); - node_to_grad[input] = &new_grad; - - //TODO: maybe add a function to get temp names - if(input->debug_name != "") { - new_grad.SetDebugName("d" + loss_value->debug_name + "_d" + input->debug_name); - } else if(input->var_name != "") { - new_grad.SetDebugName("d" + loss_value->debug_name + "_d" + input->var_name); - } - } - } - }); - - for (auto wrt_grad : loss_wrt_grad) { - if (wrt_grad.first.first != loss) { - continue; - } - - Node* grad = wrt_grad.second; - if(!node_to_grad.contains(wrt_grad.first.second)) { - throw std::runtime_error("Gradient not computed for " + wrt_grad.first.second->var_name); - } - Node* computed_grad = node_to_grad[wrt_grad.first.second]->node_; - grad_to_computed_grad[grad] = computed_grad; - } - } - - unordered_set nodes_to_remove; - //replace all gradients with computed gradients - for (auto gradient : gradients) { - Node* computed_grad = grad_to_computed_grad[gradient]; - //replace the node with the sum - gradient->ReplaceThisWithGivenNode(computed_grad); - - //mark the node for removal - nodes_to_remove.insert(gradient); - } - - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); - - return false; -} - -} // namespace TensorFrost diff --git a/TensorFrost/Compiler/Steps/GraphOps.cpp b/TensorFrost/Compiler/Steps/GraphOps.cpp deleted file mode 100644 index 85aef0ff..00000000 --- a/TensorFrost/Compiler/Steps/GraphOps.cpp +++ /dev/null @@ -1,1343 +0,0 @@ -#include "Compiler/KernelGen.h" -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { - -void KernelScope::CreateKernel() { - //create kernel node - Tensor& tensor = Tensor::Kernel(scope_shape.GetTensors()); - Node* kernel_node = tensor.node_; - Node* old_child = kernel_node->child; - kernel_node->child = begin; - kernel_node->next = end->next; - begin->parent = kernel_node; - begin->prev = nullptr; - end->next->prev = kernel_node; - end->next = old_child; - old_child->prev = end; -} - -void IR::SeparateOperationsIntoKernels() { - - unordered_set kernel_scopes; - - ExecuteExpressionFirstChild(root, [&]() { - kernel_scopes = KernelScope::ComputeScopes(root).first; - }); - - // create kernel nodes for all kernel scopes - for (auto scope : kernel_scopes) { - // create kernel before the scope - ExecuteExpressionBefore(scope->begin, [&]() { - scope->CreateKernel(); - }); -#ifdef _RELWITHDEBINFO - if (scope->boundary_nodes.size() > max_kernel_memory_dependencies) { - cout << current_pass << ": Warning: Kernel created with " << scope->boundary_nodes.size() << " boundary nodes" << endl; - } -#endif - } - - UpdateGraph(); -} - - -unordered_set IR::ComputeKernelDependencies(Node* kernel) { - unordered_set kernel_deps; - - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - //go over all inputs - unordered_set node_deps; - for (auto& [id, from] : node->args.Inputs()) { - if(id.first == ArgType::Shape) { - continue; - } - bool is_outside = !from->HasParent(kernel); - if (is_outside) { - //check if its not scalar - if(from->args.Count(ArgType::Shape) == 0) { - continue; - } - node_deps.insert(from); - } else { - node_deps.insert(from->memory_deps.begin(), from->memory_deps.end()); - } - } - //go over all outputs and add this node as a dependency if those outputs are outside the kernel - //this means the kernel must also write to an output - for (auto [edge, to] : node->args.Outputs()) { - if(edge.first.first == ArgType::Shape) { - continue; - } - if (!to->HasParent(kernel)) { - node_deps.insert(node.get()); - } - } - node->memory_deps = node_deps; - kernel_deps.insert(node_deps.begin(), node_deps.end()); - } - - kernel->memory_deps = kernel_deps; - return kernel_deps; -} - - -// check if all child nodes in a kernel have compatible shape to the kernel -void IR::CheckKernelShapes() { - // get kernels - vector kernels = GetNodesOfType("kernel"); - - // go over all outputs of each kernel and create memory nodes to store the - // output - for (auto kernel : kernels) { - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - // check if the node has a shape argument - ShapeCompareResult result = CompareShape(kernel, node.get(), true); - } - -#ifdef _RELWITHDEBINFO - auto deps = ComputeKernelDependencies(kernel); - if(deps.size() > max_kernel_memory_dependencies) { - cout << current_pass << ": Warning: Kernel " << kernel->debug_index << " has " << deps.size() << " dependencies" << endl; - } -#endif - } - - - UpdateGraph(); -} - -void IR::UpdateKernelShapes() { - // get kernels - vector kernels = GetNodesOfType("kernel"); - - // go over all outputs of each kernel and create memory nodes to store the - // output - for (auto kernel : kernels) { - NodeArguments kernel_shape = kernel->args.GetArguments(ArgType::Shape); - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - //set the shape of all nodes in the kernel to the kernel shape - node->args.RemoveArguments(ArgType::Shape); - node->args.AddArguments(kernel_shape); - } - } - - UpdateGraph(); -} - -bool IR::LimitKernelMemoryDependencies() { - UpdateGraph(); - vector kernels = GetNodesOfType("kernel"); - - int created_kernels = 0; - - for (auto kernel : kernels) { - unordered_set kernel_deps = ComputeKernelDependencies(kernel); - } - - for (auto kernel : kernels) { - if(kernel->memory_deps.size() <= max_allowed_memory_dependencies) continue; - - //throw std::runtime_error("Kernel " + to_string(kernel->index_) + " has too many memory dependencies (" + to_string(kernel_deps.size()) + " > " + to_string(max_kernel_memory_dependencies) + ")"); - } - - UpdateGraph(); - - return created_kernels == 0; -} - -void IR::UnrollOperations() { - UpdateGraph(); - vector kernels = GetNodesOfType("kernel"); - - Tensor* const_one; - Tensor* const_zero; - - ExecuteExpressionFirstChild(root, [&]() { - const_one = &Tensor::Constant(1, TFTypeInt32); - const_zero = &Tensor::Constant(0, TFTypeInt32); - }); - - vector> nodes_to_store; - for (auto kernel : kernels) { - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if(node->name == "dim_id" || node->name == "const") { - continue; - } - // check if the node has a shape argument - ShapeCompareResult result = CompareShape(kernel, node.get(), true); - if (!result.exactly_compatible) { - if(!result.unroll_compatible) continue; - ShapeInfo unroll_shape = node->tensor_->GetShapeInfo(); - vector dims = unroll_shape.GetShape(); - vector> unrolled_dims; - for (int broadcast_dim: result.broadcast_dims) { - unrolled_dims.push_back({dims[broadcast_dim], broadcast_dim}); - } - int unrolled_size = 1; - for (int i = 0; i < unrolled_dims.size(); i++) { - unrolled_size *= unrolled_dims[i].first; - } - cout << "Unrolling node " << node->name << " in kernel " << kernel->name << endl; - - ExecuteExpressionBefore(*node, [&] { - //create local memory with size = to size difference - Tensor* local_memory = nullptr; - if(node->format.type != TFType::None) { - local_memory = &Tensor::LocalMemory(unrolled_size, node->format); - } - - //create loops for each broadcasted dimension - Node* last_loop = nullptr; - map loop_indices; - for (int i = 0; i < unrolled_dims.size(); i++) { - Tensor* const_loop_size = &Tensor::Constant(unrolled_dims[i].first, TFTypeInt32); - if(last_loop == nullptr) { - last_loop = Tensor::Loop(*const_zero, *const_loop_size, *const_one).node_; - } else { - ExecuteExpressionFirstChild(last_loop, [&] { - last_loop = Tensor::Loop(*const_zero, *const_loop_size, *const_one).node_; - }); - } - loop_indices[unrolled_dims[i].second] = last_loop; - } - - //go over inputs of the node and replace them with the loop indices if they are dim_id's of the right dimensions - //if the input is a local memory, then load the value from the memory - for (auto& [id, from] : node->args.InputsCopy()) { - if(from->name == "dim_id") { - int dim = from->data[0]; - if(result.broadcast_dims.contains(dim)) { - //replace the input with the loop index - Node* loop_index = loop_indices[dim]; - node->args.UpdateArgument(id, loop_index); - } - } else if(from->op->HasAllTypes(OpProp::LocalMemory)){ - ExecuteExpressionLastChild(last_loop, [&] { - //load the value from the memory - //TODO compute proper index - Tensor* load_value = &Tensor::Load(*from->tensor_, {const_zero}); - node->args.UpdateArgument(id, load_value->node_); - }); - } - } - - //move this node inside the loop - MoveNodeTo(last_loop->GetLastChild(), node.get()); - - if(local_memory != nullptr) { - //replace all outputs of the node with the local memory - node->ReplaceThisWithGivenNode(local_memory->node_); - nodes_to_store.push_back({local_memory->node_, node.get()}); - } - - //TODO special handling of global mem load operations - }); - - } - } - } - - UpdateGraph(); - - for (auto [local_memory, node] : nodes_to_store) { - ExecuteExpressionAfter(node, [&] { - Tensor* store_value = &Tensor::Store(*local_memory->tensor_, *node->tensor_, {const_zero}); - }); - } -} - -void IR::CheckIR(string name, bool check_clustering, bool check_kernels) { -#ifdef NDEBUG - return; -#endif - UpdateGraph(); - - map invalid_nodes; - //check if the IR is clusterized correctly - for (auto node = begin(); !node.end(); node.next()) { - bool identity = node->args.Count(ArgType::Index) == 0; - - Node* prev = node->prev; - - if (prev == nullptr) continue; - - - // go over all inputs - for (auto& [id, input] : node->args.Inputs()) { - Node* to = node.get(); - - // check if inputs are before the node - if (input->index_ >= to->index_ && input->name != "const") { - if (id.first != ArgType::Shape) { - invalid_nodes[to] = "Argument " + TypeToString(id.first) + ":" + - to_string(id.second) + " is after the node"; - } - } - } - } - - string listing = PrintListing(invalid_nodes); - - if (!invalid_nodes.empty()) { - listing += "Step [" + name + "] failed. "; - throw std::runtime_error(listing); - } else { - cout << "Step [" << name << "] completed successfully: \n" << endl; - cout << listing << endl; - } -} - -void IR::ReorderOperations() { - // get kernel data - vector kernels = GetNodesOfType("kernel"); - - for (auto* kernel: kernels) { - unordered_set nodes_to_move; - // go over all nodes in the kernel and check if their inputs can be copied - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - // go over all inputs - for (auto& [id, from] : node->args.Inputs()) { - bool outside_kernel = !from->HasParent(kernel); - if (outside_kernel && !node->args.CannotMoveArgument(id)) { - // if this node is a set and its input is outside of the cluser -> - // move it inside - if (node->op->HasAllTypes(OpProp::Set)) { - nodes_to_move.insert(from); - } - } - } - } - - //TODO (Moroz): do a check on order of the moved nodes - seems to be breaking sometimes - - // move all the nodes that are outside the kernel inside - Node* kernel_begin = kernel->child; - for (auto* node : nodes_to_move) { - MoveNodeTo(kernel_begin, node); - } - } - - UpdateGraph(); -} - - -/// -/// Copy nodes together with their arguments (as far as possible) -/// -/// nodes to copy -/// if given, the indices to use -/// map between the original nodes and the copied nodes -map IR::CopyComputation( - const unordered_set& targets, const unordered_map& indices) { - - // do a depth first search to copy all the nodes required for the targets - // (only if in the same kernel) - set nodes_to_copy; - bool valid = true; - std::function dfs = [&](Node* node) { - if (nodes_to_copy.contains(node)) return; - nodes_to_copy.insert(node); - for (auto& [arg, from] : node->args.Inputs()) { - if (node->args.CannotCopyArgument(arg)) { - continue; - } - dfs(from); - } - }; - - for (Node* target : targets) { - dfs(target); - } - - - return CopyNodes(nodes_to_copy, {}, indices, targets, true); -} - -map IR::CopyNodesWithIndex(unordered_set nodes_to_copy, - unordered_map indices, - Node *cursor) { - // copy all the nodes at the beginning of the kernel - map copied_node_map; - if(cursor == nullptr) { - copied_node_map = CopyComputation(nodes_to_copy, indices); - } else { - ExecuteExpressionBefore(cursor, [&]() { - copied_node_map = CopyComputation(nodes_to_copy, indices); - }); - } - return copied_node_map; -} - -void IR::CopyArguments(ArgEdges args_to_copy, Node* cursor) -{ - unordered_set nodes_to_copy; - for (auto& [arg, out] : args_to_copy) { - nodes_to_copy.insert(arg.second); - } - - // copy all the nodes at the beginning of the kernel - map copied_node_map; - unordered_map indices; - copied_node_map = CopyNodesWithIndex(nodes_to_copy, indices, cursor); - - ReplaceArgs(args_to_copy, copied_node_map); -} - -void IR::MoveShapeOutsideKernels() { - UpdateGraph(); - // find all nodes that are used as shapes and are inside kernels - map nodes_to_copy; - for (auto node = begin(); !node.end(); node.next()) { - Node* kernel = node->GetParent("kernel"); - if (kernel == *node) { //if returns itself, then no kernel parent found - continue; - } - - // go over all outputs arguments - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (id.first != ArgType::Shape) { - continue; - } - // add the node to the set - nodes_to_copy[node.get()] = kernel; - } - } - - for (auto [ node, kernel ] : nodes_to_copy) { - //get all output arguments that are shapes - ArgEdges args_to_copy; - int earliest_output_index = INT_MAX; - Node* earliest_output = nullptr; - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (id.first == ArgType::Shape) { - args_to_copy.insert(ArgEdge(Arg(id, node), to)); - - //get the earliest output - if (to->index_ < earliest_output_index) { //wat - earliest_output_index = to->index_; - earliest_output = to; - } - } - } - - Node* common_parent = earliest_output->GetNodeWithCommonParent(kernel); - - // copy shape computation and put it before the earliest output (outside of the kernel if its inside) - CopyArguments(args_to_copy, common_parent); - ApplyChanges(false); - } -} - - -/// -/// Get all inputs of this program in the IR -/// -void IR::GetInputList() { - int input_memory_index = 0; - //MUST BE IN ORDER - for (auto node = begin(); !node.end(); node.next()) { - if (node->flags.has(NodeProp::InputMemory)) { - shape_memory_map[*node] = {}; - // add shapes to the memory inputs - for (int i = 0; i < node->args.Count(ArgType::Shape); i++) { - Node* shape_node = node->args.Get(ArgType::Shape, i); - shape_memory_map[*node][i] = shape_node; - } - - // set input memory index - int input_index = input_memory_index++; - // add shapes to the memory inputs - input_memory_map[input_index] = *node; - node->flags.set(NodeProp::InputMemory, (int64_t)input_index); - //if any of the inputs are "input_shape" then we need to add the input index to them - for (auto& [arg, from] : node->args.Inputs()) { - if (arg.first == ArgType::Shape && from->name == "input_shape") { - if(!from->flags.has(NodeProp::InputShapeMemory)) { //ONLY FIRST TIME - from->flags.set(NodeProp::InputShapeMemory, (int64_t)input_index); - } - } - } - } - } -} - -/// -/// Get all outputs of this program in the IR -/// -void IR::GetOutputList() { - for (auto node = begin(); !node.end(); node.next()) { - if (node->flags.has(NodeProp::OutputMemory)) { - if (!node->op->HasAllTypes(OpProp::Memory)) { - throw std::runtime_error( - "Compilation error: output is not a memory node"); // all outputs - // should be - // memory nodes - // at this point - } - output_memory_map[(int)node->flags.get(NodeProp::OutputMemory)] = *node; - } - if (node->op->HasAllTypes(OpProp::Modifier, OpProp::MemoryOp)) { - if (!node->HasParent("kernel")) { - writebacks++; - } - } else if (node->op->HasAllTypes(OpProp::Load, OpProp::MemoryOp)) { - if (!node->HasParent("kernel")) { - readbacks++; - } - } - } -} - -/// -/// Compute statistics about the IR -/// -void IR::ComputeStatistics() { - for (auto node = begin(); !node.end(); node.next()) { - if (node->name == "memory") { - bool is_input = node->flags.has(NodeProp::InputMemory); - bool is_output = node->flags.has(NodeProp::OutputMemory); - if (is_input) { - input_memory_count++; - } - if (!is_input && !is_output) { - temp_memory_count++; - } - } - } - - //Check if output memory map has all the outputs - if (output_memory_map.size() < output_memory_count) { - throw std::runtime_error("Output memory map does not have all the outputs, some got lost"); - } else if (output_memory_map.size() > output_memory_count) { - throw std::runtime_error("Output memory map has more outputs than expected"); - } -} - - - -unordered_set IR::GetDependencies(unordered_set nodes) { - unordered_set dependencies; - std::function dfs = [&](Node* node) - { - if (dependencies.contains(node)) { - return; - } - - dependencies.insert(node); - - //all inputs of this node are used - for (auto& [arg, from] : node->args.Inputs()) { - dfs(from); - } - - //if the node is a memory node or used as memory, then all outputs are used - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (to->args.IsChangingInput(id)) { - dfs(to); - } - } - }; - - for(auto node : nodes) { - dfs(node); - } - - return dependencies; -} - -void IR::ComputeNodeCost() -{ - for (auto node = begin(); !node.end(); node.next()) { - bool is_memory = node->op->HasAllTypes(OpProp::Memory); - unordered_map input_costs; - for (auto& [id, from] : node->args.Inputs()) { - if (id.first != ArgType::Memory && - (id.first != ArgType::Shape && !is_memory)) { - input_costs[from] = from->cost_; - } - } - float input_cost = node->op->GetCost(); - for (auto& input : input_costs) { - input_cost += abs(input.second); - } - node->cost_ = input_cost; - - //go over outputs and check if it has any load operations - bool is_used_as_memory = false; - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if(id.first == ArgType::Memory) { - is_used_as_memory = true; - break; - } - } - - if(is_used_as_memory && input_cost > 128.0) { - node->flags.set(NodeProp::NoCopyFusion); - } - } -} - -map IR::GetKernelOutputs(Node *kernel) -{ - map node_output; - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - bool is_output = node->flags.has(NodeProp::OutputMemory); - ArgEdges outputs = ArgEdges(); - - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (to == nullptr) continue; - // if is a shape or memory argument, then skip (shape is loaded on CPU) - if (id.first == ArgType::Shape) continue; - if (!to->HasParent(kernel)) { - outputs.emplace(Arg(id, *node), to); - is_output = true; - } - } - - if (is_output) { - node_output[*node] = outputs; - } - } - - return node_output; -} - -string IR::PrintListing(map node_debug) const { - return GetOperationListing(*this, false, node_debug) + "\n\n"; -} - -string IR::GetNodeListing(Node* node) const { - return GetNodeString(node, true); -} - -/// -/// Copy given nodes -/// -/// target nodes to copy -/// if given, the arguments to replace -/// if given, the indices to use -/// if given, the target nodes -/// if true, all nodes and their arguments must be copied -/// mappings between the original nodes and the copied nodes -map IR::CopyNodes( - set nodes_to_copy, - unordered_map argument_replacements, - unordered_map indices, - unordered_set targets, bool must_copy_all) { - - // if we have indices, we are basically rerunning the computation with a - // different set of indices (of possible different shape) - bool can_change_shape = !indices.empty(); - NodeArguments shape_args = NodeArguments(); - if (can_change_shape) { - // get first index - int first_index = indices.begin()->first; - Node* first_index_node = indices.at(first_index); - shape_args = first_index_node->args.GetArguments(ArgType::Shape); - } - - if (nodes_to_copy.empty()) { - return {}; - } - - //TODO: figure out a better heuristic for this, or add as a warning - //(need to implement a logging system) - // if (nodes_to_copy.size() > 1024) { - // throw std::runtime_error( - // "Copy Nodes: Copying too many nodes, something is probably " - // "wrong. Number of nodes to copy: " + - // to_string(nodes_to_copy.size())); - // } - - // copy the nodes - map copied_node_map; - for (auto node = begin(); !node.end(); node.next()) { - if (!nodes_to_copy.contains(node.get())) { - continue; - } - - Node* new_node; - - // if we have the index, use it instead - bool is_dim = node->name == "dim_id"; - bool no_index = true; - if (is_dim) { - int dim = node->data[0]; - if (indices.contains(dim)) { - Node* new_index = indices.at(dim); - if(new_index == nullptr) { - throw std::runtime_error("Copy Nodes: New index is null for node " + node->name); - } - new_node = new_index; - no_index = false; - } - } - - if (no_index) { - // create new arguments - NodeArguments new_args; - for (auto& [arg, from]: node->args.Inputs()) { - auto& [type, index] = arg; - if (can_change_shape && type == ArgType::Shape) { - continue; - } - - // if shape or memory argument, then no need to use copied node - if (node->args.CannotCopyArgument(arg) && !targets.contains(from) && !argument_replacements.contains(from)) { - if(from == nullptr) { - throw std::runtime_error("Copy Nodes: From is null for node " + node->name); - } - new_args[arg] = from; - continue; - } - - Node* new_from = from; - - if (argument_replacements.contains(from)) { - new_from = argument_replacements[from]; - } else if (nodes_to_copy.contains(from)) { - if(!copied_node_map.contains(from)) { - throw std::runtime_error("Copy Nodes: No replacement for node " + from->name); - } - new_from = copied_node_map[from]; - } else if (must_copy_all) { - throw std::runtime_error("Copy Nodes: No replacement for node " + from->name + " but we must copy all nodes"); - } - - if(new_from == nullptr) { - throw std::runtime_error("Copy Nodes: New from is null for node " + from->name); - } - - // create new argument - new_args[arg] = new_from; - } - - if (can_change_shape) { - new_args.insert(shape_args.begin(), shape_args.end()); - } - - // create new node - Tensor* tensor = Tensor::GetCopy(*node->GetTensor(), new_args); - new_node = tensor->node_; - } - - if(new_node == nullptr) { - throw std::runtime_error("Copy Nodes: New node is null for node " + node->name); - } - - copied_node_map[node.get()] = new_node; - } - - return copied_node_map; -} - - -void IR::AddNodeLoadOperations(Node* node, Node* kernel, Tensors indices) { - for (auto& [arg, input_node] : node->args.InputsCopy()) { - if (arg.first == ArgType::Memory || arg.first == ArgType::Shape) - continue; - - bool is_in_a_kernel = input_node->HasParent("kernel"); - bool is_outside = !input_node->HasParent(kernel); - bool is_memory = input_node->op->HasAllTypes(OpProp::Memory); - - if (is_memory || (is_in_a_kernel && is_outside)) { - // load the memory node before this node - ExecuteExpressionBefore(node, [&]() { - Tensor& loaded = Tensor::Load(*input_node->GetTensor(), indices, IndexingMode::Unsafe); - node->args.UpdateArgument(arg, loaded.node_); - }); - } - } -} - -void IR::AddKernelGlobalLoadOperations() { - // get kernels - vector kernels = GetNodesOfType("kernel"); - for (auto kernel : kernels) { - - // replace all inputs pointing to memory nodes with the memory node - unordered_set nodes_to_load; - unordered_map load_arguments; - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - for (auto& [arg, input_node] : node->args.Inputs()) { - if (arg.first == ArgType::Memory || arg.first == ArgType::Shape) - continue; - - bool is_in_a_kernel = input_node->HasParent("kernel"); - bool is_outside = !input_node->HasParent(kernel); - bool is_memory = input_node->op->HasAllTypes(OpProp::Memory); - - if (is_memory || (is_in_a_kernel && is_outside)) { - nodes_to_load.insert(input_node); - load_arguments[input_node].insert(ArgEdge(Arg(arg, input_node), node.get())); - } - } - } - - for (auto node : nodes_to_load) { - // load the memory node at the beginning of the kernel - ExecuteExpressionFirstChild(kernel, [&]() { - Tensor& loaded = Tensor::Load(*node->GetTensor(), {}, IndexingMode::Unsafe); - for (auto [in, out] : load_arguments[node]) { - auto& [arg, from] = in; - out->args.UpdateArgument(arg, loaded.node_); - } - }); - } - } - - UpdateGraph(); -} - - -void IR::AddMemoryOpIndices() { - // get kernels - vector kernels = GetNodesOfType("kernel"); - for (auto kernel : kernels) { - // get kernel shape arguments - NodeArguments shape_args = kernel->args.GetArguments(ArgType::Shape); - - Tensors indices = Tensors(); - // add dimension index nodes - ExecuteExpressionFirstChild(kernel, [&]() { - for (int i = 0; i < shape_args.size(); i++) { - indices.push_back(&Tensor::Index(shape_args, i)); - } - }); - int kernel_dim = (int)shape_args.size(); - - // replace all inputs pointing to memory nodes with the memory node - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (!node->op->HasAllTypes(OpProp::MemoryOp)) { - continue; - } - - Node* input_node = node->args.Get(ArgType::Memory); - map shape = input_node->args.GetTensors(ArgType::Shape); - - int memory_dim = (int)shape.size(); - ExecuteExpressionBefore(node.get(), [&]() { - for (int i = 0; i < memory_dim; i++) { - if (node->args.Has(ArgType::Index,i)) { - continue; - } - if (memory_dim > kernel_dim) { - throw std::runtime_error( - "Memory dimension is greater than kernel dimension, we can't " - "implicitly broadcast"); - } - const Tensor* index = nullptr; - // if the shape is 1, then we broadcast and set the index to 0 - if (isConstantAndEqualTo(shape[i], 1.0)) { - index = &Tensor::Constant(0); - } else { - index = indices[i]; - } - node->args.AddArgument(ArgType::Index, i, index->node_); - } - }); - } - } - - UpdateGraph(); -} - -void IR::AddKernelGlobalStoreOperations() { - UpdateGraph(); - // get kernels - vector kernels = GetNodesOfType("kernel"); - - // go over all outputs of each kernel and create memory nodes to store the - // output - for (auto kernel: kernels) { - map node_output = GetKernelOutputs(kernel); - - for (auto [output, args] : node_output) { - // if the output is already a memory node, then skip - if (output->op->HasAllTypes(OpProp::Memory)) { - continue; - } - - Node* mem; - // add memory node before this kernel - ExecuteExpressionBefore(kernel, [&]() { - mem = Tensor::Memory(kernel->args.GetArguments(ArgType::Shape), output->format).node_; - mem->debug_name = output->debug_name; - - if (output->flags.has(NodeProp::OutputMemory)) { - mem->flags.copy_all_given(output->flags, { NodeProp::OutputMemory }); - output->flags.remove(NodeProp::OutputMemory); - } - }); - - // go over all outputs of this node and replace their input with the - // memory node - for (auto& [arg, to] : args) { - auto id = arg.first; - if (id.first != ArgType::Shape && - id.first != ArgType::Memory) { - // if not a memory or shape argument, then the memory needs to be - // loaded before the node - ExecuteExpressionBefore(to, [&]() { - Tensor& loaded = Tensor::Load(*mem->GetTensor(), {}, IndexingMode::Unsafe); - // the node must now use the loaded value - to->args.UpdateArgument(id, loaded.node_); - }); - } else { - // otherwise the memory can be used directly - to->args.UpdateArgument(id, mem); - } - } - - //get last modification of the memory - Node* last_mod = output->GetFinalVersion(); - //get the parent of the last modification on the same level as the memory - Node* last_mod_parent = last_mod->GetNodeWithCommonParent(output); - - // add store node after the last modification on the same level as the memory - ExecuteExpressionAfter(last_mod_parent, [&]() { - // add store node after this node - Tensor* store = &Tensor::Store(*mem->GetTensor(), *output->GetTensor(), {}, IndexingMode::Unsafe); - }); - } - } - - // replace all inputs pointing to memory nodes with the memory node - for (auto node = begin(); !node.end(); node.next()) { - bool is_memory = node->op->HasAllTypes(OpProp::Memory); - - for (auto& [id, from] : node->args.InputsCopy()) { - if (id.first == ArgType::Memory || - (id.first == ArgType::Shape && !is_memory)) - continue; - - if (from->op->HasAllTypes(OpProp::Memory)) { - // load the memory node before this node - ExecuteExpressionBefore(node.get(), [&]() { - Tensor& loaded = Tensor::Load(*from->GetTensor(), {}, IndexingMode::Unsafe); - node->args.UpdateArgument(id, loaded.node_); - }); - } - } - } - - UpdateGraph(); -} - -void IR::AddMemoryDeallocation() -{ - vector memory_nodes = GetNodesOfType("memory"); - - // go over all outputs of each memory and and put a deallocation node after the last time it is used - for (auto memory : memory_nodes) { - // skip input and output memories, they are deallocated manually - if (memory->flags.has(NodeProp::InputMemory)) { - continue; - } - - Node* last_output = nullptr; - int last_output_index = -1; - - bool is_an_output = false; - - //do a dfs to find the last output - std::function dfs = [&](Node* node) { - if (node->flags.has(NodeProp::OutputMemory)) { - is_an_output = true; - return; - } - - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (to->op->HasAllTypes(OpProp::MemoryReuse)) { - dfs(to); - } else { - if (last_output_index < to->index_) { - last_output_index = to->index_; - last_output = to; - } - } - } - }; - - dfs(memory); - - if (is_an_output) { - continue; - } - - // need to add deallication in the same scope as the allocation - Node* deallocation_point = last_output->GetNodeWithCommonParent(memory); - - // add deallocation node after the last time the memory is used - ExecuteExpressionAfter(deallocation_point, [&]() { - Tensor* deallocate = &Tensor::Deallocate(*memory->GetTensor()); - }); - } - - UpdateGraph(); -} - -// compute the flat index (in C-order) -Tensor* ComputeFlatIndex(NodeArguments memory_shape, Tensors indices, map idx, int memory_dim, IndexingMode mode = IndexingMode::Clamp) -{ - if (memory_dim == 0) - { - return &Tensor::Constant(0); - } - - int kernel_dim = (int)indices.size(); - - function get_shape = [&](int dim) { - dim = memory_dim - dim - 1; - return memory_shape[ArgID(ArgType::Shape, dim)]->GetTensor(); - }; - - // function to get index for given dimension, if not found then return - // default dim index - function get_index = [&](int dim) { - int idxdim = memory_dim - dim - 1; - Tensor* out; - const Tensor& shape = *get_shape(dim); - if (idx.find(idxdim) != idx.end()) { - out = const_cast(idx[idxdim]); - } else { - throw std::runtime_error("Finalize memory indexing: node index not found for dimension " + to_string(idxdim) + " in memory node with dimensions " + to_string(memory_dim)); - } - - //if index is uint then cast it to int - if (out->node_->format.type == Uint) { - out = &Tensor::toint(*out); - } - - switch (mode) - { - case IndexingMode::Clamp: - return &Tensor::clamp( - *out, TensorFrost::Tensor::Constant(0), - shape - TensorFrost::Tensor::Constant(1)); - case IndexingMode::Repeat: - return &(*out - (*out / shape) * shape); - case IndexingMode::Unsafe: - return out; - default: //TODO (Moroz): add other modes - throw std::runtime_error("Finalize memory indexing: invalid tensor indexing mode"); - } - }; - - Tensors shape = Tensors(); - Tensors index = Tensors(); - for (int i = 0; i < memory_dim; i++) { - shape.push_back(get_shape(i)); - index.push_back(get_index(i)); - } - - return &Tensor::FlatIndex(shape, index); -} - -void IR::ReplaceDimNodes(Node* kernel, Tensors indices, int dims) -{ - // replace all dim nodes with the corresponding index node - unordered_set nodes_to_remove; - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (node->name == "dim_id") { // remove the dim node - nodes_to_remove.insert(node.get()); - } - else - { - //go over node inputs and replace dim nodes with index nodes - for (auto& [id, from] : node->args.InputsCopy()) { - if (from->name == "dim_id") { - int dim = from->data[0]; - Node* index_node = nullptr; - if (dim >= dims) { //if dim node of dimension greater than the number of dimensions its 0 - ExecuteExpressionBefore(node.get(), [&]() { - index_node = Tensor::Constant(0).node_; - }); - } else { - index_node = indices[dim]->node_; - } - - // replace the dim node with the index node - node->args.UpdateArgument(id, index_node); - } - } - } - } - - // remove all dim nodes - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } -} - -void IR::MultiDimensionalModeIndices(vector& indices, Node* kernel_, int dims, Tensors kernel_shape) -{ - //add dim_id nodes at the beginning of the kernel - ExecuteExpressionFirstChild(kernel_, [&]() { - for (int i = 0; i < dims; i++) { - indices[i] = &Tensor::Index(kernel_shape, i); - } - }); -} - -Tensors ComputeIndicesFromBlockIndex(Tensor* block_index, Node* kernel, - Tensors kernel_shape, int dims) { - //compute in-block index - vector block_size = kernel->group_size; - int block_dim = (int)block_size.size(); - Tensors block_size_tensors = {}; - for (int i = 0; i < block_dim; i++) { - block_size_tensors.push_back(&Tensor::Constant(block_size[i])); - } - vector in_block_indices; - for (int i = 0; i < block_dim; i++) { - in_block_indices.push_back(&block_index->BlockThreadIndex(i)); - } - - //compute out-of-block index - Tensors blocks_shape = {}; - for (int i = 0; i < block_dim; i++) { - const Tensor block_size = *block_size_tensors[i]; - const Tensor shape = *kernel_shape[i]; - Tensor& ceil = (shape + block_size - Tensor::Constant(1)) / block_size; - blocks_shape.push_back(&ceil); - blocks_shape[i]->SetDebugName("blocks_shape_" + to_string(i)); - } - for (int i = block_dim; i < dims; i++) { - blocks_shape.push_back(kernel_shape[i]); - blocks_shape[i]->SetDebugName("blocks_shape_" + to_string(i)); - } - Tensors out_block_indices = Tensor::IndicesFromFlatIndex(block_index, blocks_shape); - - //combine the final indices - Tensors indices = {}; - - for (int i = 0; i < block_dim; i++) { - indices.push_back(&(*out_block_indices[i] * *block_size_tensors[i] + *in_block_indices[i])); - indices[i]->SetDebugName("index_" + to_string(i)); - } - - for (int i = block_dim; i < dims; i++) { - indices.push_back(out_block_indices[i]); - indices[i]->SetDebugName("index_" + to_string(i)); - } - - return indices; -} - -Tensor* IR::LinearBlockModeIndices(Tensors& indices, Node* kernel_, int dims, Tensors kernel_shape) -{ - Tensor* block_index = nullptr; - Tensor* if_tensor = nullptr; - ExecuteExpressionFirstChild(kernel_, [&]() { - block_index = &kernel_->GetTensor()->BlockIndex(); - - if(kernel_->group_size.size() == 0) { //if group size is not set, then set it to default - switch (dims) - { - case 1: - kernel_->group_size = {256}; - break; - case 2: - kernel_->group_size = {16, 16}; - break; - case 3: - kernel_->group_size = {8, 8, 8}; - break; - default: - kernel_->group_size = {8, 8, 8}; - } - - //if the dimensions are known, then use the minimum of the group size and the shape to avoid useless computation - int group_dim = (int)kernel_->group_size.size(); - for (int i = 0; i < group_dim; i++) { - int shape = kernel_shape[i]->TryGetConstant(); - if (shape > 0) { - kernel_->group_size[i] = min(kernel_->group_size[i], shape); - } - } - } - - indices = ComputeIndicesFromBlockIndex(block_index, kernel_, kernel_shape, dims); - - //add a check for if inside the dispatch - Tensor* inside_dispatch = &(*indices[0] < *kernel_shape[0]); - for (int i = 1; i < dims; i++) { - inside_dispatch = &(*inside_dispatch && *indices[i] < *kernel_shape[i]); - } - inside_dispatch->SetDebugName("is_inside_dispatch"); - - //put an if condition - if_tensor = &Tensor::If(*inside_dispatch); - if_tensor->SetDebugName("if_inside_dispatch"); - }); - - ReplaceDimNodes(kernel_, indices, dims); - - return if_tensor; -} - -void IR::ComputeAddress(Node* node, Tensors indices) -{ - // get the input memory node - const Tensor* memory = node->args.GetTensor(ArgType::Memory); - - NodeArguments memory_shape = memory->node_->args.GetArguments(ArgType::Shape); - - int memory_dim = (int)memory_shape.size(); - - // get the index nodes - map idx = node->args.GetTensors(ArgType::Index); - - if (idx.empty()) - { - node->indexing_mode_ = IndexingMode::Unsafe; //we can guarantee that the index is in bounds - } - - auto flat_index= ComputeFlatIndex(memory_shape, indices, idx, memory_dim, node->indexing_mode_); - - // TODO(Moroz): add wrap mode - - // remove the index node edges - node->args.RemoveArguments(ArgType::Index); - - // add the flat index node edge - node->args.AddArgument(ArgType::Index, 0, flat_index->node_); -} - -void IR::FinalizeMemoryIndexing() { - vector kernels = GetNodesOfType("kernel"); - - vector dispatch_checks; - - for (auto kernel : kernels) { - Node* shape_node = kernel; - if (shape_node == nullptr) continue; - // load kernel shape - map kernel_shape_map = shape_node->args.GetTensors(ArgType::Shape); - Tensors kernel_shape; - for (auto& shape : kernel_shape_map) { - kernel_shape.push_back(shape.second); - } - - if (kernel_shape.empty()) { - // can skip if no kernel shape - no index - continue; - } - - // compute the index for each dimension - int dims = (int)kernel_shape.size(); - Tensors indices = Tensors(dims); - dispatch_checks.push_back( - LinearBlockModeIndices(indices, kernel, dims, kernel_shape)); - - // go over all nodes that take an index as input (e.g. load, store, atomic) - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (node->op->HasAllTypes(OpProp::MemoryOp)) { - if (node->flags.has(NodeProp::LocalMemoryOp)) { - continue; - } - ExecuteExpressionBefore(*node, [&]() { ComputeAddress(node.get(), indices); }); - } - } - } - - //now compute address for all nodes that are not in a kernel - for (auto node = begin(); !node.end(); node.next()) { - if (!node->HasParent("kernel") && node->op->HasAllTypes(OpProp::MemoryOp)) { - ExecuteExpressionBefore(node.get(), [&]() { - Tensors indices = {}; - ComputeAddress(node.get(), indices); - }); - } - } - - for (auto check : dispatch_checks) { - // put the rest of the kernel starting from if_node->next_ as a child - // MoveNodeTo(if_node->node_->child, if_node->node_->next); - Node* if_node = check->node_; - if_node->child = if_node->next; - if_node->next->prev = nullptr; - if_node->next->parent = if_node; - if_node->next = nullptr; - } - - UpdateGraph(); -} - - -void IR::TryReplaceModificationsWithVersions() -{ - UpdateGraph(); - - //get all "set" nodes - vector nodes = GetNodesOfType("set"); - - unordered_set nodes_to_remove; - - for (auto set_node : nodes) { - //look up the memory node - Node* memory_node = set_node->args.Get(ArgType::Memory); - Node* input_value = set_node->args.Get(ArgType::Input); - - //if this node has the same parent as the memory node, then it can be replaced with a version - if (memory_node->parent == set_node->parent) { - //replace the set node with the memory node - ExecuteExpressionBefore(set_node, [&]() { - Tensor& copied = Tensor::copy(*input_value->GetTensor()); - Node* copynode = copied.node_; - memory_node->ReplaceThisWithGivenNode(copynode, set_node->index_, true); - nodes_to_remove.insert(set_node); - }); - } - - UpdateIndex(); - } - - UpdateGraph(); - - // remove all nodes that are not used - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - -void IR::ApplyChanges(bool update_graph, const Node* uroot) { - for (auto& [arg, out] : edgesToUpdate) { - Node* from = arg.second; - if (!replacementNodes.contains(from)) { - throw std::runtime_error("No replacement node found for node " + from->name); - } - Node* to = replacementNodes[from]; - out->args.UpdateArgument(arg.first, to); - } - - // remove all nodes that are not used - for (auto* node : removedNodes) { - RemoveNode(node); - } - - if (update_graph) { - UpdateGraph(uroot); - } - - ClearChanges(); -} - -void IR::ClearChanges() { - edgesToUpdate.clear(); - removedNodes.clear(); - replacementNodes.clear(); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/Steps/Optimization.cpp b/TensorFrost/Compiler/Steps/Optimization.cpp deleted file mode 100644 index 64c15813..00000000 --- a/TensorFrost/Compiler/Steps/Optimization.cpp +++ /dev/null @@ -1,760 +0,0 @@ -#include "Compiler/KernelGen.h" - -namespace TensorFrost { - -bool isConstantAndEqualTo(const Tensor* tensor, float value) { - if (tensor->node_->name != "const" || tensor->node_->flags.has(NodeProp::Modified)) { - return false; - } - - switch (tensor->node_->format.type) { - case TFType::Float: - return AsFloat(tensor->node_->data[0]) == value; - case TFType::Int: - return AsInt(tensor->node_->data[0]) == value; - case TFType::Uint: - return tensor->node_->data[0] == value; - default: - throw std::runtime_error("Unexpected type in isConstantAndEqualTo"); - } -} - -bool isConstant(const Tensor* tensor) { - return tensor->node_->name == "const" && !tensor->node_->flags.has(NodeProp::Modified); -} - -Tensor* ApplyMultiOP(const Tensor* a, const Tensor* b, std::function opF32, std::function opI32, std::function opU32) { - switch (a->node_->format.type) { - case TFType::Float: - return &Tensor::Constant(opF32(AsFloat(a->node_->data[0]), AsFloat(b->node_->data[0]))); - case TFType::Int: - return &Tensor::Constant(opI32(AsInt(a->node_->data[0]), AsInt(b->node_->data[0]))); - case TFType::Uint: - return &Tensor::Constant(opU32(a->node_->data[0], b->node_->data[0])); - default: - throw std::runtime_error("Unexpected type in ApplyMultiOP"); - } -} - -Tensor* ApplyUnaryOP(const Tensor* a, std::function opF32, std::function opI32, std::function opU32) { - switch (a->node_->format.type) { - case TFType::Float: - return &Tensor::Constant(opF32(AsFloat(a->node_->data[0]))); - case TFType::Int: - return &Tensor::Constant(opI32(AsInt(a->node_->data[0]))); - case TFType::Uint: - return &Tensor::Constant(opU32(a->node_->data[0])); - default: - throw std::runtime_error("Unexpected type in ApplyUnaryOP"); - } -} - -#define ApplyOP(v1, v2, op) ApplyMultiOP(v1, v2, [](float a, float b) { return a op b; }, [](int a, int b) { return a op b; }, [](uint a, uint b) { return a op b; }) -#define ApplyFUNC(v1, v2, func) ApplyMultiOP(v1, v2, [](float a, float b) { return func(a, b); }, [](int a, int b) { return func(a, b); }, [](uint a, uint b) { return func(a, b); }) -#define ApplyUOP(v1, op) ApplyUnaryOP(v1, [](float a) { return op a; }, [](int a) { return op a; }, [](uint a) { return op a; }) -#define ApplyUFUNC(v1, func) ApplyUnaryOP(v1, [](float a) { return func(a); }, [](int a) { return func(a); }, [](uint a) { return func(a); } - -void IR::OptimizeOperations() -{ - for (auto node = begin(); !node.end(); node.next()) { - //get node operation - const string op = node->name; - - //get inputs - map inputs = node->args.GetTensors(ArgType::Input); - ExecuteExpressionAfter(*node, [&]() { - const Tensor* result = nullptr; - if (op == "add") { - // if any are zero, replace with the other - if (isConstantAndEqualTo(inputs[0], 0.0F)) { - // replace with input 1 - result = inputs[1]; - } else if (isConstantAndEqualTo(inputs[1], 0.0F)) { - // replace with input 0 - result = inputs[0]; - } - - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1])) { - // replace with result - result = ApplyOP(inputs[0], inputs[1], +); - } - } else if (op == "sub") { - // if any are zero, replace with the other - if (isConstantAndEqualTo(inputs[0], 0.0F)) { - // replace with negation of input 1 - result = &(-*inputs[1]); - } else if (isConstantAndEqualTo(inputs[1], 0.0F)) { - // replace with input 0 - result = inputs[0]; - } - - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1])) { - // compute result - result = ApplyOP(inputs[0], inputs[1], -); - } - - //if both are the same node, then replace with zero - if (inputs[0]->node_ == inputs[1]->node_) { - result = &Tensor::Constant(0u, inputs[0]->node_->format); - } - } else if (op == "mul") { - // if any are zero, replace with zero - if (isConstantAndEqualTo(inputs[0], 0.0F) || - isConstantAndEqualTo(inputs[1], 0.0F)) { - // replace with zero - result = &Tensor::Constant(0u, inputs[0]->node_->format); - } - - // if any are one, replace with the other - if (isConstantAndEqualTo(inputs[0], 1.0F)) { - // replace with input 1 - result = inputs[1]; - } else if (isConstantAndEqualTo(inputs[1], 1.0F)) { - // replace with input 0 - result = inputs[0]; - } - - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1])) { - // compute result - result = ApplyOP(inputs[0], inputs[1], *); - } - } else if (op == "div") { - // if first is zero, replace with zero - if (isConstantAndEqualTo(inputs[0], 0.0F)) { - // replace with zero - result = &Tensor::Constant(0u, inputs[0]->node_->format); - } - - // if second is one, replace with first - if (isConstantAndEqualTo(inputs[1], 1.0F)) { - // replace with input 0 - result = inputs[0]; - } - - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1])) { - // compute result - result = ApplyOP(inputs[0], inputs[1], /); - } - } - else if (op == "clamp") { - // if all are constants, replace with result - if (isConstant(inputs[0]) && isConstant(inputs[1]) && isConstant(inputs[2])) { - // compute result - result = ApplyFUNC(inputs[0], inputs[1], max); - result = ApplyFUNC(result, inputs[2], min); - } - } - else if(op == "neg") { - if(isConstant(inputs[0])) { - result = ApplyUnaryOP(inputs[0], [](float a) { return -a; }, [](int a) { return -a; }, [](uint a) { return a; }); - } - } - else if(op == "dim_id") { //if the shape of the dimension is 1 then replace with 0 - int dim = node->data[0]; - const Tensor* shape = node->args.Get(ArgType::Shape, dim)->GetTensor(); - if(isConstantAndEqualTo(shape, 1.0F)) { - result = &Tensor::Constant(0u, TFTypeInt32); - } - } - //TODO (Moroz): add more optimizations - - // if computed optimized result, replace all node references with it - if (result != nullptr) - { - node->ReplaceThisWithGivenNode(result->node_, -1, false, false); - } - }); - } -} - -void IR::OptimizeHostValuesWithHints() -{ - for (auto node = begin(); !node.end(); node.next()) { - //if node inside kernel - skip - if(node->HasParent("kernel")) continue; - - ExecuteExpressionAfter(*node, [&]() { - const Tensor* result = nullptr; - - //if node has a max value hint, then replace it with it - if(node->flags.has(NodeProp::HintMaxValue)) { - int64_t max_value = node->flags.get(NodeProp::HintMaxValue); - result = &Tensor::Constant((uint)max_value, node->format); - } - - if (result != nullptr) - { - for (auto [edge, to] : node->args.OutputsCopy()) { - auto& [id, from] = edge; - //if(to->HasParent("kernel")) continue; #TODO (Moroz): check if this is needed - to->args.UpdateArgument(id, result->node_); - } - } - }); - } - - UpdateGraph(); -} - -void IR::RemoveUnusedOperations() { - unordered_set used_nodes; - //mark all output nodes as used - for (auto node = begin(); !node.end(); node.next()) { - if (node->flags.has(NodeProp::OutputMemory) || - node->flags.has(NodeProp::InputMemory) || - node->op->HasAllTypes(OpProp::Static)) { - used_nodes.insert(node.get()); - } - } - - used_nodes = GetDependencies(used_nodes); - - // remove all nodes that are not used - unordered_set nodes_to_remove; - for (auto node = begin(); !node.end(); node.next()) { - if (!used_nodes.contains(node.get())) { - if (!node->flags.has(NodeProp::InputMemory) && !node->flags.has(NodeProp::OutputMemory)) - { - nodes_to_remove.insert(node.get()); - } - } - } - - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - - -void IR::RemoveUnusedKernels() -{ - vector kernels = GetNodesOfType("kernel"); - vector nodes_to_remove; - - for (auto kernel : kernels) { - // remove all kernel nodes that dont do anything - int memory_modifiers = 0; - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (node->op->HasAllTypes(OpProp::Modifier, OpProp::MemoryOp)) { - memory_modifiers++; - } - //if any output is outside the kernel, then the kernel is needed - for (auto [edge, to] : node->args.Outputs()) { - auto& [id, from] = edge; - if (!to->HasParent(kernel)) { - memory_modifiers++; - } - } - } - if (memory_modifiers == 0) nodes_to_remove.push_back(kernel); - } - - // remove all nodes that are not used - for (auto* node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - -#define MAX_UNROLL_NODES 128 -void IR::UnrollLoops(int max_iterations) -{ - vector loops = GetNodesOfType("loop"); - - unordered_set loops_to_remove; - - for (auto loop : loops) { - //get inputs (begin, end, step) - map inputs = loop->args.GetTensors(ArgType::Input); - - //try get the constant values - bool is_const = isConstant(inputs[0]) && isConstant(inputs[1]) && isConstant(inputs[2]); - - bool has_other_loops = loop->HasChild("loop") || loop->HasParent("loop"); - bool has_child_kernel = loop->HasChild("kernel"); - - if (!is_const || has_other_loops) { - continue; - } - - int begin = inputs[0]->TryGetConstant(); - int end = inputs[1]->TryGetConstant(); - int step = inputs[2]->TryGetConstant(); - - //how many iterations to unroll - int iters = (end - begin) / step; - if (iters > max_iterations) { - continue; - } - - //get all children of the loop - vector children = GetChildren(loop); - - if (children.size() > MAX_UNROLL_NODES) { -#ifdef _RELWITHDEBINFO - cout << current_pass << ": Warning: Loop has too many children to unroll" << endl; -#endif - continue; - } - - //check if they are not keywords or have no children - set nodes_to_copy; - bool can_unroll = true; - for (auto child : children) { - if (child->op->class_ == OpClass::Keyword || child->child->valid()) { - can_unroll = false; - break; - } - nodes_to_copy.insert(child); - } - - if (!can_unroll) { - continue; - } - - //unroll the loop - ExecuteExpressionAfter(loop, [&]() { - for (int i = begin; i < end; i += step) { - unordered_map arg_remap; - Tensor* index = &Tensor::Constant(i); - //index->SetDebugName(loop->debug_name + "_unroll_" + to_string(i)); - arg_remap[loop] = index->node_; - CopyNodes(nodes_to_copy, arg_remap, {}, {}, false); - } - }); - - //mark the loop for removal - loops_to_remove.insert(loop); - } - - // remove all loops that are not used - for (auto* loop : loops_to_remove) { - RemoveNode(loop); - } - - UpdateGraph(); -} - -void IR::UnrollAtomicOperations() { - vector atomics = GetNodesOfType(OpProp::Scatter, OpProp::MemoryOp, OpProp::Modifier); - - vector nodes_to_remove; - for (auto node: atomics) { - if(node->flags.has(NodeProp::LocalMemoryOp)) continue; - - std::set unused_dimensions; - Node* next_node = node->next; - int dim = node->args.Count(ArgType::Shape); - for (int i = 0; i < dim; i++) { - unused_dimensions.insert(i); - } - - //get the indices of the scatter operation - NodeArguments indices = node->args.GetArguments(ArgType::Index); - //get dependencies of all indices - unordered_set index_nodes; - for (auto& [id, index] : indices) { - index_nodes.insert(index); - } - unordered_set dependencies = GetDependencies(index_nodes); - //if any of the dependencies are a dim_id node, then its dimension(data[0]) is used - for (auto dep : dependencies) { - if (dep->name == "dim_id") { - unused_dimensions.erase(dep->data[0]); - } - } - - int unused_count = (int)unused_dimensions.size(); - - if (unused_count == 0) { - continue; - } - - auto old_shape = node->args.GetTensors(ArgType::Shape); - - //reduce the dimensions of the scatter operation - const Tensor* current_reduce = node->args.Get(ArgType::Input, 0)->GetTensor(); - vector unused_dims_vec(unused_dimensions.begin(), unused_dimensions.end()); - //sort the dimensions in descending order - std::reverse(unused_dims_vec.begin(), unused_dims_vec.end()); - bool supported_operation = true; - - BeginScope(next_node); - - for(int udim : unused_dims_vec) { - if(node->name == "InterlockedAdd") { - current_reduce = &Tensor::Sum(*current_reduce, udim); - } else if (node->name == "InterlockedMin") { - current_reduce = &Tensor::Min(*current_reduce, udim); - } else if (node->name == "InterlockedMax") { - current_reduce = &Tensor::Max(*current_reduce, udim); - } else if (node->name == "InterlockedAnd") { - current_reduce = &Tensor::All(*current_reduce, udim); - } else if (node->name == "InterlockedOr") { - current_reduce = &Tensor::Any(*current_reduce, udim); - } else { - supported_operation = false; - break; - } - } - - if (!supported_operation) { - EndScope(); -#ifdef _RELWITHDEBINFO - cout << current_pass << ": Warning: Unsupported atomic operation " << node->name << endl; -#endif - continue; - } - - unordered_map old_to_new_dim; - int new_idx = 0; - - for (int i = 0; i < dim; i++) { - if (!unused_dimensions.contains(i)) { - old_to_new_dim[i] = current_reduce->Index(new_idx++).node_; - } else { - old_to_new_dim[i] = Tensor::Constant(current_reduce->GetShape(), 0, current_reduce->GetFormat()).node_; - } - } - - map copied_node_map = CopyNodesWithIndex(index_nodes, old_to_new_dim); - Tensors new_indices = Tensors(); - new_indices.resize(indices.size()); - for (auto& [id, index] : indices) { - new_indices[id.second] = copied_node_map[index]->GetTensor(); - } - - //get the memory to scatter to - const Tensor* memory = node->args.Get(ArgType::Memory)->GetTensor(); - Tensor* store_op = nullptr; - Tensor* old_value = nullptr; - if(false) { //TODO (Moroz): check if indexes are one-to-one, otherwise must use atomic operations - old_value = &Tensor::Load(*memory, new_indices); - Tensor* new_value = nullptr; - if(node->name == "InterlockedAdd") { - new_value = &(*old_value + *current_reduce); - } else if (node->name == "InterlockedMin") { - new_value = &Tensor::min(*old_value, *current_reduce); - } else if (node->name == "InterlockedMax") { - new_value = &Tensor::max(*old_value, *current_reduce); - } else if (node->name == "InterlockedAnd") { - new_value = &(*old_value & *current_reduce); - } else if (node->name == "InterlockedOr") { - new_value = &(*old_value | *current_reduce); - } - store_op = &Tensor::Store(*memory, *new_value, new_indices); - } else { //still use atomic operations - store_op = &Tensor::MemoryOp(node->name, memory, new_indices, current_reduce); - } - - nodes_to_remove.push_back(node); - - EndScope(); - } - - for (auto node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - -#define MIN_SPLIT_SIZE 1024 -#define SPLIT_DIM_SIZE 128 - -void IR::OptimizeReductions() { - vector reductions = GetNodesOfType(OpProp::Algorithm, OpProp::Reduction); - - vector nodes_to_remove; - for (auto node : reductions) { - int axis = (int)node->data[0]; - //get the input tensor - const Tensor* input = node->args.Get(ArgType::Input, 0)->GetTensor(); - //get the shape of the tensor at the reduction axis - const Tensor* tensor = input->node_->args.Get(ArgType::Shape, axis)->GetTensor(); - //try to get the constant value - int axis_value = tensor->TryGetConstant(); - if (axis_value < 0) { -#ifdef _RELWITHDEBINFO - cout << current_pass << ": Warning: Can not apply reduction optimization on non-constant axis" << endl; -#endif - continue; - } - //if size of the axis is less than the minimum split size, then do not split - if (axis_value < MIN_SPLIT_SIZE) { - continue; - } - - ExecuteExpressionAfter(node, [&]() { - //split dimension into smaller chunks and reduce them sequentially - const Tensor* split = &Tensor::SplitDim(*input, SPLIT_DIM_SIZE, axis); - Tensor* result = &Tensor::ReductionOP(node->name, *split, axis, false); - result = &Tensor::ReductionOP(node->name, *result, axis, false); - node->ReplaceThisWithGivenNode(result->node_); - nodes_to_remove.push_back(node); - }); - } - - for (auto node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); -} - -#define MAX_KERNEL_COPY_COST 50000.0f -bool IR::OptimizeKernels() { - // get kernel data - vector kernels = GetNodesOfType("kernel"); - ComputeNodeCost(); - - bool changed = false; - // go over each kernel and copy computations outside the kernel if they are - // cheap enough - for (auto kernel : kernels) { - ArgEdges args_to_copy; - ArgEdges shape_args_to_copy; - // go over all nodes in the kernel and check if their inputs can be copied - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - // go over all inputs - for (auto& [arg, from]: node->args.Inputs()) { - bool inside_kernel = from->HasParent(kernel); - bool from_in_kernel = from->HasParent("kernel"); - - if(from->flags.has(NodeProp::NoCopyFusion)) continue; - - if (!inside_kernel && !node->args.CannotCopyArgument(arg)) - { - // check if input is cheap enough to copy - float input_cost = from->cost_; - if (input_cost == -1.0) { -#ifdef _RELWITHDEBINFO - cout << current_pass << ": Warning: Could not determine cost of node " << from->name << endl; -#endif - continue; - } - bool cheap_enough = input_cost >= 0.0f && input_cost < MAX_KERNEL_COPY_COST; - bool has_only_one_output = from->args.OutputCount() == 1; - if (cheap_enough || has_only_one_output) { - args_to_copy.insert(ArgEdge(Arg(arg, from), *node)); - } - } - //shape arguments can not be inside kernels - if (from_in_kernel && arg.first == ArgType::Shape) { - shape_args_to_copy.insert(ArgEdge(Arg(arg, from), *node)); - } - } - } - -// auto kernel_deps_old = ComputeKernelDependencies(kernel); - - //go over kernel shape arguments - for (int i = 0; i < kernel->args.Count(ArgType::Shape); i++) { - Node* shape_node = kernel->args.Get(ArgType::Shape, i); - bool from_in_kernel =shape_node->HasParent("kernel"); - if (from_in_kernel) { - shape_args_to_copy.insert(ArgEdge(Arg(ArgID(ArgType::Shape, i), shape_node), kernel)); - } - } - - //copy the nodes that are outside the kernel inside - CopyArguments(args_to_copy, kernel->child); - -// auto kernel_deps = ComputeKernelDependencies(kernel); - -// //if more than allowed then do not apply changes -// if (kernel_deps.size() > max_kernel_memory_dependencies && kernel_deps_old.size() < kernel_deps.size() && !must_copy_all) { -// ClearChanges(); -// #ifdef _RELWITHDEBINFO -// std::cout << current_pass << ": Warning: Discarding kernel optimization changes for kernel " << kernel->name << " with " << kernel_deps.size() << " dependencies while before it had " << kernel_deps_old.size() << " dependencies" << endl; -// #endif -// } - - //copy shape arguments before the kernel - CopyArguments(shape_args_to_copy, kernel); - if (!args_to_copy.empty()) { - changed = true; - } - - ApplyChanges(false); - } - - return !changed; -} - -#define MAX_LOAD_COPY 3000.0f -#define MAX_LOAD_COPY_COUNT 2 -#define MAX_LOAD_SIZE_RATIO 0.5f - -bool IR::OptimizeKernelLoadOperations() { - ComputeNodeCost(); - - vector kernels = GetNodesOfType("kernel"); - - unordered_set nodes_to_remove; - - size_t loads_fused = 0; - - for (auto kernel : kernels) { - ShapeInfo kernel_shape = ShapeInfo(kernel); - - unordered_map loads_to_copy; - unordered_set memory_inputs; - // go over all nodes in the kernel and check if their inputs can be copied - for (auto node = NodeIterator(kernel); !node.end(); node.next()) { - if (node->name != "load") continue; - - if (node->flags.has(NodeProp::NoLoadFusion)) continue; - - //get memory input - Node* memory_input = node->args.Get(ArgType::Memory); - - if(memory_input->flags.has(NodeProp::StopFusion)) continue; - - ShapeInfo memory_shape = ShapeInfo(memory_input); - - bool inside_kernel = memory_input->HasParent("kernel"); - if (!inside_kernel) continue; - - bool is_not_modified = !memory_input->flags.has(NodeProp::Modified); - if (!is_not_modified) continue; - - float kernel_size = ShapeInfo::GetSizeEstimate(kernel_shape); - float memory_size = ShapeInfo::GetSizeEstimate(memory_shape); - float size_ratio = kernel_size / memory_size; - - int output_count = (int)memory_input->args.OutputCount(); - //only fuse if this is used less than MAX_LOAD_COPY_COUNT times or we can reduce dimensionality by fusing - bool fusion_makes_sense = (output_count < MAX_LOAD_COPY_COUNT) || - (size_ratio <= MAX_LOAD_SIZE_RATIO) || memory_size == 1.0f; - bool cheap_enough = memory_input->cost_ >= 0.0f && - memory_input->cost_ < (MAX_LOAD_COPY / output_count); - - - //if the memory input is used only once and is not a memory node - if (cheap_enough && fusion_makes_sense) { - loads_to_copy[memory_input] = *node; - memory_inputs.insert(memory_input); - } - } - - if (loads_to_copy.empty()) continue; - - //auto kernel_deps_old = ComputeKernelDependencies(kernel); - - for (auto load : loads_to_copy) { - //get the load - Node* memory_input = load.first; - Node* load_node = load.second; - - //get the indices - unordered_map indices; - for (auto& [arg, from] : load_node->args.Inputs()) { - if (arg.first == ArgType::Index) { - indices[arg.second] = from; - } - } - - //copy the load node - map copied_node_map = CopyNodesWithIndex({ memory_input }, indices, load_node); - - - Tensors indices_tensors = Tensors(); - indices_tensors.resize(indices.size()); - for (auto& [index, node] : indices) { - indices_tensors[index] = node->GetTensor(); - } - - //go over all the copied nodes and add load nodes to their inputs that are outside the kernel - for (auto& [old_node, new_node] : copied_node_map) { - AddNodeLoadOperations(new_node, kernel, indices_tensors); - } - - Node* copied_load = copied_node_map[memory_input]; - //copy over the information from the original load node - copied_load->CopyMetadata(load_node); - - map replacements; replacements[load_node] = copied_load; - - ReplaceArgs(load_node->args.Outputs(), replacements); - } - - //auto kernel_deps = ComputeKernelDependencies(kernel); - -// //if more than allowed then do not apply changes -// if (kernel_deps.size() > max_kernel_memory_dependencies && kernel_deps_old.size() < kernel_deps.size()) { -// ClearChanges(); -// #ifdef _RELWITHDEBINFO -// std::cout << current_pass << ": Warning: Discarding kernel load fusion changes for kernel " << kernel->name << " with " << kernel_deps.size() << " dependencies, before it had " << kernel_deps_old.size() << " dependencies" << endl; -// #endif -// continue; -// } - - //remove the load node since it is not needed anymore - for (auto load : loads_to_copy) { - nodes_to_remove.insert(load.second); - } - - loads_fused += loads_to_copy.size(); - ApplyChanges(false); - } - - // remove the load nodes - for (auto node : nodes_to_remove) { - RemoveNode(node); - } - - UpdateGraph(); - - return loads_fused == 0; -} - - - -#define MAX_HOST_COPY_COST 8192.0f - -void IR::OptimizeHost() { - ComputeNodeCost(); - - //loop over all nodes and copy their arguments if they are cheap enough and inside kernels - for (auto node = begin(); !node.end(); node.next()) { - if (node->HasParent("kernel")) { - continue; - } - - ArgEdges args_to_copy; - // go over all inputs - for (auto& [arg, from] : node->args.Inputs()) { - bool inside_kernel = from->HasParent("kernel"); - - if (inside_kernel && !node->args.CannotCopyArgument(arg)) { - // check if input is cheap enough to copy - float input_cost = from->cost_; - if (input_cost == -1.0) { - //throw std::runtime_error("Cost has not been computed for node " + input.from_->get()->var_name); - continue; - } - bool cheap_enough = input_cost >= 0.0f && input_cost < MAX_HOST_COPY_COST; - bool has_only_one_output = from->args.OutputCount() == 1; - - if (cheap_enough || has_only_one_output) { - args_to_copy.insert(ArgEdge(Arg(arg, from), *node)); - } else { - throw std::runtime_error("Host optimization: Copy cost too high for node " + node->name + " with cost " + to_string(input_cost)); - } - } - } - - CopyArguments(args_to_copy, node.get()); - ApplyChanges(false); - } -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Compiler/include/Compiler/Common.h b/TensorFrost/Compiler/include/Compiler/Common.h new file mode 100644 index 00000000..7d4179d2 --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/Common.h @@ -0,0 +1,148 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace TensorFrost { +extern "C" { + enum TFType { + Float, + Uint, + Int, + Bool, + None, + Unknown, + }; + + struct TFDataFormat { + TFType type; + size_t size; + + bool operator==(const TFDataFormat& other) const; + bool operator!=(const TFDataFormat& other) const; + int GetHash() const; + bool operator<(const TFDataFormat& other) const; + bool operator>(const TFDataFormat& other) const; + }; + +#define TFNone TFDataFormat{TFType::None, 0} +#define TFUnknown TFDataFormat{TFType::Unknown, 0} +#define TFBool TFDataFormat{TFType::Bool, 32} +#define TFFloat32 TFDataFormat{TFType::Float, 32} +#define TFInt32 TFDataFormat{TFType::Int, 32} +#define TFUint32 TFDataFormat{TFType::Uint, 32} +} + +// Utility class to automatically resize and set elements in a vector +template +class auto_vector : public std::vector { +public: + void set_element(size_t index, T&& value) { + if (index >= this->size()) { + this->resize(index + 1); + } + (*this)[index] = std::forward(value); + } +}; + +template> +struct VecHash { + size_t operator()(const std::vector& v) const noexcept { + size_t h = 0; + H hasher; + for (const auto& x : v) + h ^= hasher(x) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; + } +}; + +inline std::string ToString(const TFDataFormat& format) { + switch (format.type) { + case TFType::Float: return "float" + std::to_string(format.size); + case TFType::Uint: return "uint" + std::to_string(format.size); + case TFType::Int: return "int" + std::to_string(format.size); + case TFType::Bool: return "bool" + std::to_string(format.size); + case TFType::None: return "void"; + default: return "unknown"; + } +} + +using uint = unsigned int; + +struct Op; +struct OpBlock; +class OpBlockIterator; +struct ArgumentManager; +struct Argument; +class Value; +struct Shape; + +using Attribute = std::variant; +using AttributeMap = std::unordered_map; +using AttributeVector = std::vector; +using Values = std::vector; + +TFDataFormat GetTypeFromAttribute(const Attribute& attr); + +//ostringstream conversion for Attribute +inline std::ostream& operator<<(std::ostream& os, const Attribute& attr) { + std::visit([&os](const auto& v) { os << v; }, attr); + return os; +} + +inline std::string ToString(const Attribute& attr) { + std::ostringstream oss; + oss << attr; + return oss.str(); +} + +template +auto TransformVector(const Container& input, Func func) { + using T2 = decltype(func(*std::begin(input))); + std::vector output; + output.reserve(input.size()); + for (const auto& item : input) { + output.push_back(func(item)); + } + return output; +} + +template +auto ConcatVectors(const std::vector& a, const std::vector& b) { + std::vector result; + result.reserve(a.size() + b.size()); + result.insert(result.end(), a.begin(), a.end()); + result.insert(result.end(), b.begin(), b.end()); + return result; +} + +template +auto SliceVector(const std::vector& vec, size_t start, size_t end = -1) { + if (end == -1 || end > vec.size()) { + end = vec.size(); + } + if (start >= end || start >= vec.size()) { + return std::vector(); + } + return std::vector(vec.begin() + start, vec.begin() + end); +} +} + +namespace std { +template<> +struct hash { + size_t operator()(const TensorFrost::TFDataFormat& f) const noexcept { + return static_cast(f.GetHash()); + } +}; +} diff --git a/TensorFrost/Compiler/include/Compiler/ExecutionContext.h b/TensorFrost/Compiler/include/Compiler/ExecutionContext.h new file mode 100644 index 00000000..9cfa9e46 --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/ExecutionContext.h @@ -0,0 +1,30 @@ +#pragma once + +#include "Common.h" +#include "Operation.h" + +namespace TensorFrost { + +struct ExecutionContext { + std::unique_ptr base_block; + std::stack cursor_stack; + + ExecutionContext(); + void BeginCursor(OpBlock::Iterator it); + void EndCursor(); + + Op &Add(std::unique_ptr op); + Op &AddBeforeCursor(std::unique_ptr op); +}; + +void StartExecutionContext(ExecutionContext* ctx); +ExecutionContext* GetContext(); +OpBlock* GetBaseBlock(); +OpBlock* GetCurrentBlock(); +void BeginCursor(OpBlock::Iterator it); +void BeginCursor(OpBlock& block); +void BeginCursor(Op* op); +void EndCursor(); +void EndExecutionContext(); + +} diff --git a/TensorFrost/Compiler/include/Compiler/Operation.h b/TensorFrost/Compiler/include/Compiler/Operation.h new file mode 100644 index 00000000..a6c61ab5 --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/Operation.h @@ -0,0 +1,63 @@ +#pragma once + +#include "Common.h" +#include "OperationArguments.h" +#include "OperationBlocks.h" +#include "OperationRegistry.h" +#include "ExecutionContext.h" +#include "Overloads.h" + +namespace TensorFrost { + +struct Op { + std::string opcode; + std::unique_ptr args; + AttributeMap attributes; + TFDataFormat type; + std::vector> blocks; + int output_count = 1; + + OpBlock* parent_block = nullptr; + size_t index = 0; //might not be up to date + std::string varname; + + Op(std::string op_name); + OpBlock* NewBlock(); + OpBlock& GetBlock(int index = 0); + + void AddAttribute(const std::string& name, const Attribute& value); + void ChangeAttribute(const std::string& name, const Attribute& value); + + Attribute GetAttribute(const std::string &name) const; + + bool Compare(const Op& other) const; +}; + +void ApplyOpTransform(OpBlock& block, const std::function& transform); +void IterateOver(OpBlock &block, const std::function &transform); +std::set CollectDependencies(std::vector ops); + +template +State IterateWithLocalState(OpBlock &block, Fn &&f) { + State current = {}; + for (auto it = block.begin(); it.valid(); it.next()) { + std::vector kids; + for (auto &sb : it->blocks) + kids.push_back(IterateWithLocalState(*sb, f)); + f(it, current, kids); + } + return current; +} + +template +void IterateWithGlobalState(OpBlock &block, State& current, Fn &&f) { + for (auto it = block.begin(); it.valid(); it.next()) { + std::vector kids; + for (auto &sb : it->blocks) + kids.push_back(IterateWithGlobalState(*sb,current, f)); + f(it, current, kids); + } +} + +} // namespace ir + diff --git a/TensorFrost/Compiler/include/Compiler/OperationArguments.h b/TensorFrost/Compiler/include/Compiler/OperationArguments.h new file mode 100644 index 00000000..118f2a36 --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/OperationArguments.h @@ -0,0 +1,31 @@ +#pragma once +#include "Common.h" + +namespace TensorFrost { + +struct Argument { + Op* from = nullptr; + Op* to = nullptr; + int arg_index = 0; // Index in to's arguments + int from_index = 0; // Index of from's output + + Value From() const; +}; + +struct ArgumentManager { + Op* parent_op = nullptr; + auto_vector> inputs; + std::set> used_at; + + ArgumentManager(Op* parent); + void AddArgument(Value from, int arg_index = 0); + void SetAsOutput(Argument *arg); + void RemoveOutput(Argument *arg); + void SetArguments(Values args); + void Remove(int index); + void RemoveAll(); + + Values Inputs() const; +}; + +} diff --git a/TensorFrost/Compiler/include/Compiler/OperationBlocks.h b/TensorFrost/Compiler/include/Compiler/OperationBlocks.h new file mode 100644 index 00000000..cd0d4f34 --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/OperationBlocks.h @@ -0,0 +1,47 @@ +#pragma once +#include "Operation.h" + +namespace TensorFrost { +struct OpBlock { + using List = std::list>; + using It = List::iterator; + + Op* parent_op = nullptr; + List ops; + + OpBlock(Op* parent = nullptr); + + class Iterator { + OpBlock* parent_; + List* list_; + It cur_; + + public: + Iterator(OpBlock *parent, List::iterator it); + + Op* operator*() const; + Op* operator->() const; + Op* get_next() const; + Op* get_prev() const; + + Iterator& next(); + Iterator& prev(); + Iterator& insert_after(std::unique_ptr op); + Iterator& insert_before(std::unique_ptr op); + Iterator& remove(); + + Iterator& move_before(Iterator other); + Iterator& move_range_before(Iterator other_start, Iterator other_end); + + OpBlock* parent() const { return parent_; } + + bool valid() const; + bool operator==(const Iterator& o) const; + bool operator!=(const Iterator& o) const; + }; + + Iterator begin(); + Iterator end(); +}; + +} diff --git a/TensorFrost/Compiler/include/Compiler/OperationRegistry.h b/TensorFrost/Compiler/include/Compiler/OperationRegistry.h new file mode 100644 index 00000000..26f2289a --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/OperationRegistry.h @@ -0,0 +1,109 @@ +#pragma once +#include "Common.h" + +namespace TensorFrost { + +enum class OpClass { + Operator, + UnaryOperator, + Function, + Copy, + Keyword, + Parallel, + Variable, + TypeCast, + TypeReinterpret, + Constant, + TernaryOperator, + Memory, + Phi, + Set, + None, +}; + +enum class OpProp { + HasShape, + Load, + Store, + MemoryOp, + Set, +}; + +using FoldFn = std::function; +using CalcTuple = std::function; + +[[noreturn]] inline void bad_arity(std::size_t expect, std::size_t got) +{ + throw std::invalid_argument("constant-fold expects " + std::to_string(expect) + + " operands, got " + std::to_string(got)); +} + +template +FoldFn make_fold1(F f) { + return [f = std::move(f)](AttributeVector a) -> Attribute { + if (a.size() != 1) bad_arity(1, a.size()); + return std::visit([&](auto &&x) -> Attribute { + return f(std::forward(x)); + }, a[0]); + }; +} + +template +FoldFn make_fold2(F f) { + return [f = std::move(f)](AttributeVector a) -> Attribute { + if (a.size() != 2) bad_arity(2, a.size()); + return std::visit([&](auto &&x, auto &&y) -> Attribute { + return f(std::forward(x), std::forward(y)); + }, a[0], a[1]); + }; +} + +template +FoldFn make_fold3(F f) { + return [f = std::move(f)](AttributeVector a) -> Attribute { + if (a.size() != 3) bad_arity(3, a.size()); + return std::visit([&](auto &&x, auto &&y, auto &&z) -> Attribute { + return f(std::forward(x), std::forward(y), std::forward(z)); + }, a[0], a[1], a[2]); + }; +} + +enum class ArgProp { + IgnoreShape +}; + +struct ArgSpec { + char out; + std::vector in; + std::map> types; + std::map> props; + bool variadic = false; + + ArgSpec(std::string io, + std::map> types = {}, + std::map> props = {}); + + bool ValidArgCount(size_t count) const; + + bool IsValid(std::vector inputs, TFDataFormat output) const; + TFDataFormat EstimateOutputType(const std::vector& inputs) const; + std::vector> InputProperties(const std::vector& inputs) const; +}; + +struct OpSpec { + std::string name; + ArgSpec arg_spec; + OpClass op_class = OpClass::None; + std::set props; + int blocks = 0; + FoldFn const_fold = nullptr; + CalcTuple calc_tuple = nullptr; + + TFDataFormat GetOutputType(const std::vector& args) const; + bool IsValid(const std::vector& inputs, TFDataFormat output) const; +}; + +void RegisterOperation(const OpSpec& spec); +OpSpec* GetOpSpec(const std::string& name); + +} \ No newline at end of file diff --git a/TensorFrost/Compiler/include/Compiler/Overloads.h b/TensorFrost/Compiler/include/Compiler/Overloads.h new file mode 100644 index 00000000..e58d2cbe --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/Overloads.h @@ -0,0 +1,30 @@ +#pragma once +#include "Operation.h" +#include "Value.h" + +namespace TensorFrost { +std::pair create_op(std::string op, const Values& args, TFDataFormat output_type = TFUnknown); +Value value_op(std::string op, Values args = {}, TFDataFormat output_type = TFUnknown); +Values tuple_op(std::string op, Values args = {}, TFDataFormat output_type = TFUnknown); + +Value constant(int value); +Value constant(uint value); +Value constant(float value); +Value constant(bool value); + +void vmap(Values shape, std::function body); +Value memory(Values shape, TFDataFormat type); +Value load_at_index(Value mem, Values indices); +void if_cond(Value cond, std::function body_true, std::function body_false = nullptr); +Value loop(Value start, Value end, Value step, std::function body); +Value phi(Values inputs); + +inline Value toint(Value x) { return value_op("toint", {x}); } +inline Value tofloat(Value x) { return value_op("tofloat", {x}); } +inline Value touint(Value x) { return value_op("touint", {x}); } +inline Value tobool(Value x) { return value_op("tobool", {x}); } + +inline Value sin(Value x) { return value_op("sin", {x}); } +inline Value cos(Value x) { return value_op("cos", {x}); } +inline Value tan(Value x) { return value_op("tan", {x}); } +} diff --git a/TensorFrost/Compiler/include/Compiler/Printer.h b/TensorFrost/Compiler/include/Compiler/Printer.h new file mode 100644 index 00000000..4ba12306 --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/Printer.h @@ -0,0 +1,16 @@ +#pragma once +#include "Operation.h" +#include "OperationBlocks.h" + +namespace TensorFrost { + +std::string VariableName(const Op* op); +std::string PrintOp(const Op* op); +std::string PrintBlock(OpBlock& base_block); +void AssignVariableNames(OpBlock &block); +std::string PrintAttribute(Attribute attr); +std::string AddIndent(const std::string& str, int indent); +std::string PrintArray(std::vector items, const std::string& begin = "", const std::string& end = "", + const std::string& separator = ", "); + +} diff --git a/TensorFrost/Compiler/include/Compiler/TFProgram.h b/TensorFrost/Compiler/include/Compiler/TFProgram.h new file mode 100644 index 00000000..fed163df --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/TFProgram.h @@ -0,0 +1,22 @@ +#pragma once +#include "Operation.h" +#include "ExecutionContext.h" +#include "Printer.h" + +namespace TensorFrost { +class TFProgram { +public: + ExecutionContext context; + std::vector program_inputs; + std::vector program_outputs; + + TFProgram(std::function, std::vector>()> program_fn); + + void Compile(); + void ConstantFold(); + void RemoveUnused(); + void CombineVmapDepthwise(); + + std::string DebugPrint() const; +}; +} \ No newline at end of file diff --git a/TensorFrost/Compiler/include/Compiler/Value.h b/TensorFrost/Compiler/include/Compiler/Value.h new file mode 100644 index 00000000..11516113 --- /dev/null +++ b/TensorFrost/Compiler/include/Compiler/Value.h @@ -0,0 +1,68 @@ +#pragma once +#include "Operation.h" + +namespace TensorFrost { + +// Op thin wrapper class for overloaded mathematics and manipulations +class Value { +public: + Op* op = nullptr; + int out_index = 0; // Index of the output value in the operation + + Value() = default; + Value(Op* operation, int from_index = 0); + Value(const Op* operation, int from_index = 0); + Value(float value); + Value(int value); + Value(uint value); + Value(bool value); + Value(const Value& other); + + // indexed access + Value operator[](const Values& indices) const; + + // binary ops take const ref and are const themselves + Value operator+(const Value& other) const; + Value operator-(const Value& other) const; + Value operator*(const Value& other) const; + Value operator/(const Value& other) const; + Value operator%(const Value& other) const; + Value operator==(const Value& other) const; + Value operator!=(const Value& other) const; + Value operator<(const Value& other) const; + Value operator<=(const Value& other) const; + Value operator>(const Value& other) const; + Value operator>=(const Value& other) const; + Value operator&&(const Value& other) const; + Value operator||(const Value& other) const; + Value operator<<(const Value& other) const; + Value operator>>(const Value& other) const; + + // unary ops + Value operator!() const; + Value operator-() const; + Value operator+() const; + Value operator~() const; + + bool Compare(const Value& other) const; + + void Set(Value value); +}; + +std::vector values_to_ops(const Values& values); +Values ops_to_values(const std::vector& ops); + +struct Shape { + Values dimensions; + Shape(Values dims) : dimensions(std::move(dims)) {} + Shape(std::initializer_list dims) : dimensions(dims) {} + Shape() = default; + Shape(const Shape& other) : dimensions(other.dimensions) {} + + void AddDimension(const Value& dim); + void AddDimensions(const Values& dims); + bool Broadcastable(const Shape& other) const; +}; + +Shape ComputeShape(Value x); +} diff --git a/TensorFrost/Compiler/src/Common.cpp b/TensorFrost/Compiler/src/Common.cpp new file mode 100644 index 00000000..87b6dace --- /dev/null +++ b/TensorFrost/Compiler/src/Common.cpp @@ -0,0 +1,36 @@ +#include "Compiler/Common.h" + +namespace TensorFrost { +bool TFDataFormat::operator==(const TFDataFormat &other) const { + return type == other.type && size == other.size; +} + +bool TFDataFormat::operator!=(const TFDataFormat &other) const { + return !(*this == other); +} + +int TFDataFormat::GetHash() const { + return (int)type << 16 | (int)size; +} + +bool TFDataFormat::operator<(const TFDataFormat &other) const { + return GetHash() < other.GetHash(); +} + +bool TFDataFormat::operator>(const TFDataFormat &other) const { + return GetHash() > other.GetHash(); +} + +TFDataFormat GetTypeFromAttribute(const Attribute& attr) { + if (std::holds_alternative(attr)) { + return TFInt32; + } else if (std::holds_alternative(attr)) { + return TFUint32; + } else if (std::holds_alternative(attr)) { + return TFFloat32; + } else if (std::holds_alternative(attr)) { + return TFBool; + } + throw std::runtime_error("Unsupported attribute type for TFDataFormat conversion"); +} +} diff --git a/TensorFrost/Compiler/src/ExecutionContext.cpp b/TensorFrost/Compiler/src/ExecutionContext.cpp new file mode 100644 index 00000000..6a5b2c4d --- /dev/null +++ b/TensorFrost/Compiler/src/ExecutionContext.cpp @@ -0,0 +1,95 @@ +#include "Compiler/ExecutionContext.h" +#include "Compiler/Operation.h" +#include "Compiler/OperationBlocks.h" + +namespace TensorFrost { +ExecutionContext::ExecutionContext(): base_block(std::make_unique()) { + cursor_stack.push(base_block->begin()); +} + +void ExecutionContext::BeginCursor(OpBlock::Iterator it) { + cursor_stack.push(it); +} + +void ExecutionContext::EndCursor() { + if (cursor_stack.empty()) { + throw std::runtime_error("This is the last cursor, cannot end it"); + } + cursor_stack.pop(); +} + +Op& ExecutionContext::Add(std::unique_ptr op) { + cursor_stack.top().insert_before(std::move(op)); + Op* new_op = *cursor_stack.top(); + cursor_stack.top().next(); // Move the cursor to the new op + return *new_op; +} + +Op& ExecutionContext::AddBeforeCursor(std::unique_ptr op) { + cursor_stack.top().insert_before(std::move(op)); + return **cursor_stack.top(); +} + +ExecutionContext* current_context = nullptr; + +void StartExecutionContext(ExecutionContext* ctx) { + if (current_context) { + throw std::runtime_error("Execution context already started"); + } + if (!ctx) { + throw std::invalid_argument("Execution context cannot be null"); + } + current_context = ctx; +} + +ExecutionContext* GetContext() { + return current_context; +} + +OpBlock* GetBaseBlock() { + if (!current_context) { + throw std::runtime_error("No execution context available"); + } + return current_context->base_block.get(); +} + +OpBlock* GetCurrentBlock() { + if (!current_context) { + throw std::runtime_error("No execution context available"); + } + return current_context->cursor_stack.top().parent(); +} + +void BeginCursor(OpBlock::Iterator it) { + GetContext()->BeginCursor(it); +} + +void BeginCursor(OpBlock& block) { + GetContext()->BeginCursor(block.begin()); +} + +void BeginCursor(Op* op) { + if (!op || !op->parent_block) { + throw std::runtime_error("Op does not belong to a block"); + } + OpBlock::Iterator it(op->parent_block, op->parent_block->ops.begin()); + // Find the iterator for the specific op + for (; it.valid(); it.next()) { + if (*it == op) { + GetContext()->BeginCursor(it); + return; + } + } +} + +void EndCursor() { + GetContext()->EndCursor(); +} + +void EndExecutionContext() { + if (!current_context) { + throw std::runtime_error("No execution context to end"); + } + current_context = nullptr; +} +} \ No newline at end of file diff --git a/TensorFrost/Compiler/src/Operation.cpp b/TensorFrost/Compiler/src/Operation.cpp new file mode 100644 index 00000000..9785fcd4 --- /dev/null +++ b/TensorFrost/Compiler/src/Operation.cpp @@ -0,0 +1,96 @@ +#include "Compiler/Operation.h" + +namespace TensorFrost { +Op::Op(std::string op_name): opcode(std::move(op_name)) { + args = std::make_unique(this); + type = TFNone; +} + +OpBlock* Op::NewBlock() { + blocks.emplace_back(std::make_unique(this)); + return blocks.back().get(); +} + +OpBlock& Op::GetBlock(int index) { + if (index < 0 || index >= blocks.size()) { + throw std::out_of_range("Block index out of range"); + } + return *blocks[index]; +} + +void Op::AddAttribute(const std::string &name, const Attribute &value) { + if (attributes.find(name) != attributes.end()) { + throw std::runtime_error("Attribute '" + name + "' already exists in operation '" + opcode + "'"); + } + attributes[name] = value; +} + +void Op::ChangeAttribute(const std::string &name, const Attribute &value) { + if (attributes.find(name) == attributes.end()) { + throw std::runtime_error("Attribute '" + name + "' not found in operation '" + opcode + "'"); + } + attributes[name] = value; +} + +Attribute Op::GetAttribute(const std::string &name) const { + auto it = attributes.find(name); + if (it == attributes.end()) { + throw std::runtime_error("Attribute '" + name + "' not found in operation '" + opcode + "'"); + } + return it->second; +} + +bool Op::Compare(const Op &other) const { + bool both_const = (opcode == "const" && other.opcode == "const"); + if (both_const) { + // Compare constant values directly + Attribute this_value = GetAttribute("value"); + Attribute other_value = other.GetAttribute("value"); + return (this_value == other_value); + } + return false; // TODO: Implement more complex comparison logic for non-constant operations +} + +void ApplyOpTransform(OpBlock &block, const std::function &transform) { + for (auto& op : block.ops) { + for (auto& sub_block : op->blocks) { + ApplyOpTransform(*sub_block, transform); + } + transform(*op); + } +} + +void IterateOver(OpBlock &block, const std::function &transform) { + for (OpBlock::Iterator it = block.begin(); it.valid(); it.next()) { + for (auto& sub_block : it->blocks) { + IterateOver(*sub_block, transform); + } + transform(it); + } +} + +void ReverseIterateOver(OpBlock &block, const std::function &transform) { + for (OpBlock::Iterator it = block.end(); it.valid(); it.prev()) { + for (auto& sub_block : it->blocks) { + ReverseIterateOver(*sub_block, transform); + } + transform(it); + } +} + +std::set CollectDependencies(std::vector ops) { + std::set dependencies; + std::function collect_dependencies = [&](Op* op) { + if (op == nullptr || dependencies.contains(op)) return; // Already processed + dependencies.insert(op); + for (auto& input : op->args->inputs) { + collect_dependencies(input->from); + } + collect_dependencies(op->parent_block->parent_op); // Parent depends on this operation + }; + for (Op* op : ops) { + collect_dependencies(op); + } + return dependencies; +} +} diff --git a/TensorFrost/Compiler/src/OperationArguments.cpp b/TensorFrost/Compiler/src/OperationArguments.cpp new file mode 100644 index 00000000..719c7f59 --- /dev/null +++ b/TensorFrost/Compiler/src/OperationArguments.cpp @@ -0,0 +1,55 @@ +#include "Compiler/Operation.h" + +namespace TensorFrost { + +Value Argument::From() const { + return Value(from, from_index); +} + +ArgumentManager::ArgumentManager(Op *parent): parent_op(parent) { +} + +void ArgumentManager::AddArgument(Value from, int arg_index) { + inputs.set_element(arg_index, std::make_unique(Argument{from.op, parent_op, arg_index, from.out_index})); + from.op->args->SetAsOutput(inputs[arg_index].get()); +} + +void ArgumentManager::SetAsOutput(Argument *arg) { + used_at.insert({arg->arg_index, arg}); +} + +void ArgumentManager::RemoveOutput(Argument *arg) { + used_at.erase({arg->arg_index, arg}); +} + +void ArgumentManager::SetArguments(Values args) { + for (size_t i = 0; i < args.size(); ++i) { + AddArgument(args[i], (int)i); + } +} + +void ArgumentManager::Remove(int index) { + if (index < 0 || index >= inputs.size()) { + throw std::out_of_range("Index out of range for arguments"); + } + Argument *arg = inputs[index].get(); + if(!arg || !arg->from) throw std::runtime_error("Invalid argument"); + arg->from->args->RemoveOutput(arg); + inputs[index].reset(); +} + +void ArgumentManager::RemoveAll() { + for (size_t i = 0; i < inputs.size(); ++i) { + Remove((int)i); + } + inputs.clear(); +} + +Values ArgumentManager::Inputs() const { + Values result; + for (const auto& arg : inputs) { + if (arg) result.push_back(arg->From()); + } + return result; +} +} diff --git a/TensorFrost/Compiler/src/OperationBlocks.cpp b/TensorFrost/Compiler/src/OperationBlocks.cpp new file mode 100644 index 00000000..ebc7cfc3 --- /dev/null +++ b/TensorFrost/Compiler/src/OperationBlocks.cpp @@ -0,0 +1,71 @@ +#include "Compiler/Operation.h" + +namespace TensorFrost { +OpBlock::OpBlock(Op *parent): parent_op(parent) {} + +OpBlock::Iterator::Iterator(OpBlock *parent, List::iterator it) + : parent_(parent), list_(&parent->ops), cur_(it) {} + +Op* OpBlock::Iterator::operator*() const { return cur_->get(); } +Op* OpBlock::Iterator::operator->() const { return cur_->get(); } + +Op * OpBlock::Iterator::get_next() const { return cur_->get(); } +Op * OpBlock::Iterator::get_prev() const { + if (cur_ == list_->begin()) return nullptr; + return std::prev(cur_)->get(); +} + +OpBlock::Iterator & OpBlock::Iterator::next() { if (cur_ != list_->end()) ++cur_; return *this; } +OpBlock::Iterator & OpBlock::Iterator::prev() { if (cur_ != list_->begin()) --cur_; return *this; } + +OpBlock::Iterator& OpBlock::Iterator::insert_after(std::unique_ptr op) { + if (op->parent_block) throw std::runtime_error("Op already belongs to a block"); + + auto pos = (cur_ == list_->end()) ? list_->end() : std::next(cur_); + cur_ = list_->insert(pos, std::move(op)); // <- after-cursor + cur_->get()->parent_block = parent_; + return *this; +} + +OpBlock::Iterator& OpBlock::Iterator::insert_before(std::unique_ptr op) { + if (op->parent_block) throw std::runtime_error("Op already belongs to a block"); + + cur_ = list_->insert(cur_, std::move(op)); // <- before-cursor + cur_->get()->parent_block = parent_; + return *this; +} + +OpBlock::Iterator& OpBlock::Iterator::remove() { + if (cur_ == list_->end()) return *this; // Nothing to remove + cur_->get()->parent_block = nullptr; // Clear parent block reference + cur_ = list_->erase(cur_); // Remove and update iterator + return *this; +} + +OpBlock::Iterator& OpBlock::Iterator::move_before(Iterator other) +{ + auto pos = cur_; + list_->splice(pos, *other.list_, other.cur_); + cur_ = std::prev(pos); + cur_->get()->parent_block = parent_; + return *this; +} + +OpBlock::Iterator& OpBlock::Iterator::move_range_before(Iterator other_start, Iterator other_end) +{ + auto pos = cur_; + for(auto it = other_start.cur_; it != other_end.cur_; ++it) + it->get()->parent_block = parent_; + list_->splice(pos, *other_start.list_, other_start.cur_, other_end.cur_); + cur_ = std::prev(pos); + return *this; +} + +bool OpBlock::Iterator::valid() const { return cur_ != list_->end(); } +bool OpBlock::Iterator::operator==(const Iterator &o) const { return cur_ == o.cur_; } +bool OpBlock::Iterator::operator!=(const Iterator &o) const { return cur_ != o.cur_; } + +OpBlock::Iterator OpBlock::begin() { return Iterator(this, ops.begin()); } +OpBlock::Iterator OpBlock::end() { return Iterator(this, ops.end()); } + +} diff --git a/TensorFrost/Compiler/src/OperationRegistry.cpp b/TensorFrost/Compiler/src/OperationRegistry.cpp new file mode 100644 index 00000000..5e00db8c --- /dev/null +++ b/TensorFrost/Compiler/src/OperationRegistry.cpp @@ -0,0 +1,254 @@ +#include +#include + +#include "Compiler/Operation.h" + +using namespace std; + +namespace TensorFrost { +ArgSpec::ArgSpec(std::string io, std::map> types, + std::map> props) { + this->props = std::move(props); + this->types = std::move(types); + if (io.empty()) { + throw std::invalid_argument("Argument specification cannot be empty"); + } + // Parse argument specification (e.g., "x(x,y,t)" -> out = 'x', in = {'x', 'y', 't'}, or "z(y,z,...)" -> out = 'z', in = {'y', 'z'}, variadic = true) + out = io[0]; + //get substring between parentheses + size_t start = io.find('('); + size_t end = io.find(')', start); + //split the substring by commas + std::stringstream ss(io.substr(start + 1, end - start - 1)); + std::string token; + while (std::getline(ss, token, ',')) { + if (token == "...") { + variadic = true; + break; // Variadic argument, stop parsing further + } + if (token.empty()) continue; // Skip empty tokens + in.push_back(token[0]); + } +} + +bool ArgSpec::ValidArgCount(size_t count) const { + return variadic ? count >= in.size() - 1 : count == in.size(); +} + +bool ArgSpec::IsValid(std::vector inputs, TFDataFormat output) const { + bool valid_count = ValidArgCount(inputs.size()); + if (!valid_count) return false; + + auto name_of = [&](size_t i) -> const char& { + return variadic ? in.front() : in[i]; + }; + + std::unordered_map seen; + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& n = name_of(i); + if (seen.count(n) && !(seen[n] == inputs[i])) + return false; // Conflicting types for arg + seen[n] = inputs[i]; + + auto a = types.find(n); + if (a != types.end() && !a->second.count(inputs[i])) + return false; + } + + auto ao = types.find(out); + if (ao != types.end() && !ao->second.count(output)) return false; + if (seen.count(out) && !(seen[out] == output)) return false; + + if (variadic) { + for (size_t i = 1; i < inputs.size(); ++i) + if (!(inputs[i] == inputs[0])) return false; + } + return true; +} + +TFDataFormat ArgSpec::EstimateOutputType(const std::vector &inputs) const { + bool valid_count = ValidArgCount(inputs.size()); + if (!valid_count) return TFUnknown; + + auto name_of = [&](size_t i) -> const char& { + return variadic ? in.front() : in[i]; + }; + + std::unordered_map seen; + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& n = name_of(i); + if (seen.count(n) && !(seen[n] == inputs[i])) + return TFUnknown; // Conflicting types for arg + seen[n] = inputs[i]; + } + + if (seen.count(out)) return seen[out]; + + auto ao = types.find(out); + if (ao != types.end() && ao->second.size() == 1) + return *ao->second.begin(); + + return TFUnknown; +} + +std::vector> ArgSpec::InputProperties(const std::vector &inputs) const { + if ((!variadic && inputs.size() != in.size()) || + (variadic && (in.empty() || inputs.empty()))) + return {}; + + auto name_of = [&](size_t i) -> const char& { + return variadic ? in.front() : in[i]; + }; + + std::vector> out; + out.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& n = name_of(i); + auto p = props.find(n); + out.push_back(p == props.end() ? std::set() : p->second); + } + return out; +} + +TFDataFormat OpSpec::GetOutputType(const std::vector &args) const { + TFDataFormat ret = arg_spec.EstimateOutputType(args); + return ret; +} + +bool OpSpec::IsValid(const std::vector& inputs, TFDataFormat output) const { + return arg_spec.IsValid(inputs, output); +} + +#define BIN_OP_FOLD(op) \ +make_fold2([](auto a, auto b) { \ + if constexpr (std::is_same_v, bool> || std::is_same_v, bool>) { \ + return static_cast(a) op static_cast(b); \ + } else { \ + return a op b; \ + } \ +}) + +#define UN_OP_FOLD(op) \ + make_fold1([](auto a) { return op a; }) + +#define UN_FUNC_FOLD(op) \ + make_fold1([](auto a) { return op(a); }) + +namespace { + +template +auto promote_for_compare(T&& value) { + using Decayed = std::decay_t; + if constexpr (std::is_same_v) { + return static_cast(value); + } else { + return static_cast(value); + } +} + +template +FoldFn make_compare_fold_impl(Compare cmp) { + return make_fold2([cmp = std::move(cmp)](auto a, auto b) -> Attribute { + auto lhs = promote_for_compare(std::forward(a)); + auto rhs = promote_for_compare(std::forward(b)); + using Common = std::common_type_t; + return (cmp)(static_cast(lhs), static_cast(rhs)); + }); +} + +} // namespace + +#define BIN_FUNC_FOLD(op) \ + make_compare_fold_impl(op) + +#define TERN_FUNC_FOLD(op) \ + make_fold3([](auto a, auto b, auto c) { return op(a, b, c); }) + +#define DEF_OP(op_name, overload_str, operation_class, ...) \ + OpSpec{ .name = op_name, .arg_spec = ArgSpec overload_str, .op_class = operation_class, __VA_ARGS__ } + +vector default_operations = { + DEF_OP("memory", ("x(y,...)"), OpClass::Memory, .props = {OpProp::HasShape}), + DEF_OP("load", ("x(x,y,...)", {{'y', {TFInt32}}}, {{'x', {ArgProp::IgnoreShape}}}), + OpClass::Function, .props = {OpProp::Load, OpProp::MemoryOp}), + DEF_OP("store", ("x(x,x,y,...)", {{'y', {TFInt32}}}, {{'x', {ArgProp::IgnoreShape}}}), + OpClass::Function, .props = {OpProp::Store, OpProp::MemoryOp}), + + DEF_OP("const", ("x()"), OpClass::Constant), + DEF_OP("copy", ("x(x)"), OpClass::Copy), + DEF_OP("set", ("x(x,x)"), OpClass::Set), + DEF_OP("add", ("x(x,x)"), OpClass::Operator, + .const_fold = BIN_OP_FOLD(+)), + DEF_OP("sub", ("x(x,x)"), OpClass::Operator, + .const_fold = BIN_OP_FOLD(-)), + DEF_OP("mul", ("x(x,x)"), OpClass::Operator, + .const_fold = BIN_OP_FOLD(*)), + DEF_OP("div", ("x(x,x)"), OpClass::Operator, + .const_fold = BIN_OP_FOLD(/)), + DEF_OP("sin", ("x(x)"), OpClass::UnaryOperator, + .const_fold = UN_FUNC_FOLD(std::sinf)), + DEF_OP("cos", ("x(x)"), OpClass::UnaryOperator, + .const_fold = UN_FUNC_FOLD(std::cosf)), + DEF_OP("tan", ("x(x)"), OpClass::UnaryOperator, + .const_fold = UN_FUNC_FOLD(std::tanf)), + + DEF_OP("eq", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::equal_to<>())), + DEF_OP("ne", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::not_equal_to<>())), + DEF_OP("lt", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::less<>())), + DEF_OP("le", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::less_equal<>())), + DEF_OP("gt", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::greater<>())), + DEF_OP("ge", ("x(y,y)", {{'x', {TFBool}}}), OpClass::Operator, + .const_fold = BIN_FUNC_FOLD(std::greater_equal<>())), + + DEF_OP("tofloat", ("x(y)", {{'x', {TFFloat32}}}), OpClass::Function, + .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("toint", ("x(y)", {{'x', {TFInt32}}}), OpClass::Function, + .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("touint", ("x(y)", {{'x', {TFUint32}}}), OpClass::Function, + .const_fold = UN_FUNC_FOLD(static_cast)), + DEF_OP("tobool", ("x(y)", {{'x', {TFBool}}}), OpClass::Function, + .const_fold = UN_FUNC_FOLD(static_cast)), + + // Operations with blocks + DEF_OP("vmap", ("x(x,...)", {{'x', {TFInt32}}}), OpClass::Parallel, .props = {OpProp::HasShape}, .blocks = 1, + .calc_tuple = [](Op* op, Values args) -> Values { + Values tuple; + op->output_count = args.size(); + for (size_t i = 0; i < args.size(); ++i) tuple.push_back(Value(op, i)); + return tuple; + }), + DEF_OP("if_cond", ("x(y)", {{'x', {TFNone}}, {'y', {TFBool}}}), OpClass::Function, .blocks = 2), + DEF_OP("loop", ("x(x,x,x)", {{'x', {TFInt32}}}), OpClass::Function, .blocks = 1), + + DEF_OP("phi", ("x(x,...)"), OpClass::Phi), +}; + +std::unordered_map> CreateOperationRegistry() { + std::unordered_map> registry; + for (const auto& op : default_operations) { + registry[op.name] = std::make_unique(op); + } + return registry; +} + +std::unordered_map> operation_registry = CreateOperationRegistry(); + +void TensorFrost::RegisterOperation(const OpSpec &spec) { + if (operation_registry.contains(spec.name)) { + throw std::runtime_error("Operation already registered: " + spec.name); + } + operation_registry[spec.name] = std::make_unique(spec); +} + +OpSpec* TensorFrost::GetOpSpec(const std::string &name) { + if (!operation_registry.contains(name)) { + throw std::runtime_error("Operation not found: " + name); + } + return operation_registry[name].get(); +} +} \ No newline at end of file diff --git a/TensorFrost/Compiler/src/Overloads.cpp b/TensorFrost/Compiler/src/Overloads.cpp new file mode 100644 index 00000000..eef5a224 --- /dev/null +++ b/TensorFrost/Compiler/src/Overloads.cpp @@ -0,0 +1,117 @@ +#include "Compiler/Operation.h" +#include "Compiler/ExecutionContext.h" +#include "Compiler/Value.h" +#include "Compiler/Printer.h" + +using namespace std; + +namespace TensorFrost { + +// General function to create an Op instance in the current execution context +std::pair create_op(std::string op, const Values& args, TFDataFormat output_type) { + OpSpec* spec = GetOpSpec(op); + vector arg_types; + for (const auto& arg : args) { + arg_types.push_back(arg.op->type); + } + if (output_type == TFUnknown) { + output_type = spec->GetOutputType(arg_types); + } + // if (output_type == TFUnknown) { + // throw std::runtime_error("Cannot determine output type for operation '" + op + "'"); + // } + Op* op_instance = new Op(op); + op_instance->type = output_type; + op_instance->args->SetArguments(args); + + // Create blocks + for (int i = 0; i < spec->blocks; ++i) { + op_instance->NewBlock(); + } + op_instance = &GetContext()->Add(std::unique_ptr(op_instance)); + bool valid = spec->IsValid(arg_types, output_type); + if (!valid) { + throw std::runtime_error("Invalid operation types for '" + op + "' with arguments: " + + PrintArray(TransformVector(values_to_ops(args), PrintOp), "[", "]", ", \n")); + } + + // Check if the shape is compatible + Shape shape = ComputeShape(Value(op_instance)); + std::vector> input_props = spec->arg_spec.InputProperties(arg_types); + for (size_t i = 0; i < args.size(); ++i) { + if (input_props[i].contains(ArgProp::IgnoreShape)) continue; + if (!shape.Broadcastable(ComputeShape(args[i]))) { + throw std::runtime_error("Incompatible shape for argument " + to_string(i) + " in operation '" + op + "'"); + } + } + + return {op_instance, spec}; +} + +Value value_op(std::string op, Values args, TFDataFormat output_type) { + auto [op_instance, spec] = create_op(op, args, output_type); + if (spec->calc_tuple) throw std::runtime_error("Make op only creates single output operations, use calc_tuple for multi-output ops"); + return Value(op_instance); +} + +Values tuple_op(std::string op, Values args, TFDataFormat output_type) { + auto [op_instance, spec] = create_op(op, args, output_type); + if (!spec->calc_tuple) throw std::runtime_error("Make tuple op only works for operations with multiple outputs"); + return spec->calc_tuple(op_instance, args); +} + +Value constant(Attribute value) { + Value const_op = value_op("const", {}, GetTypeFromAttribute(value)); + const_op.op->attributes["value"] = value; + return const_op; +} + +Value constant(int value) { return constant(Attribute(value)); } +Value constant(uint value) { return constant(Attribute(value)); } +Value constant(float value) { return constant(Attribute(value)); } +Value constant(bool value) { return constant(Attribute(value)); } + +Value get_output(Value x, int index) { + return Value(x.op, index); +} + +void vmap(Values shape, std::function body) { + Values indices = tuple_op("vmap", shape); + GetContext()->BeginCursor(indices[0].op->GetBlock().begin()); + body(indices); + GetContext()->EndCursor(); +} + +Value memory(Values shape, TFDataFormat type) { + return value_op("memory", std::move(shape), type); +} + +Value load_at_index(Value mem, Values indices) { + return value_op("load", ConcatVectors({mem}, indices)); +} + +void if_cond(Value cond, std::function body_true, std::function body_false) { + Value if_op = value_op("if_cond", {cond}); + GetContext()->BeginCursor(if_op.op->GetBlock(0).begin()); + body_true(); + GetContext()->EndCursor(); + if (body_false) { + GetContext()->BeginCursor(if_op.op->GetBlock(1).begin()); + body_false(); + GetContext()->EndCursor(); + } +} + +Value loop(Value start, Value end, Value step, std::function body) { + Value loop_op = value_op("loop", {start, end, step}); + GetContext()->BeginCursor(loop_op.op->GetBlock().begin()); + body(loop_op); + GetContext()->EndCursor(); + return loop_op; +} + +Value phi(Values inputs) { + return value_op("phi", inputs); +} + +} diff --git a/TensorFrost/Compiler/src/Printer.cpp b/TensorFrost/Compiler/src/Printer.cpp new file mode 100644 index 00000000..39a55c50 --- /dev/null +++ b/TensorFrost/Compiler/src/Printer.cpp @@ -0,0 +1,112 @@ +#include "Compiler/Operation.h" +#include "Compiler/ExecutionContext.h" +#include "Compiler/Printer.h" + +using namespace std; + +namespace TensorFrost { + +std::string VariableName(const Op* op) { + if (op->opcode == "const") { + return ToString(op->attributes.at("value")); + } + return op->varname; +} + +std::vector StringifyArguments(const auto_vector>& vec) { + return TransformVector(vec, [](const std::unique_ptr& arg) { + return VariableName(arg->from) + (arg->from->output_count > 1 ? "[" + std::to_string(arg->from_index) + "]" : ""); + }); +} + +std::string PrintArray(std::vector items, const std::string &begin, const std::string &end, const std::string& separator) { + std::ostringstream oss; + if (items.empty()) return ""; + oss << begin; + bool first = true; + for (const auto& item : items) { + if (item.empty()) continue; // Skip empty items + if (!first) oss << separator; + first = false; + oss << item; + } + oss << end; + return oss.str(); +} + +std::string PrintArguments(const auto_vector>& vec, string begin, string end) { + return PrintArray(StringifyArguments(vec), begin, end); +} + +std::string PrintShape(const Shape& shape) { + std::vector dims; + for (const auto& dim : shape.dimensions) { + dims.push_back(VariableName(dim.op)); + } + return PrintArray(dims, "[", "]", ", "); +} + +std::string PrintAttribute(Attribute attr) { + std::ostringstream oss; + std::visit([&oss](const auto& v) { oss << v; }, attr); + return oss.str(); +} + +std::string PrintOp(const Op* op) { + std::ostringstream os; + os << ToString(op->type) << " " << op->varname; + if (op->opcode == "const") { + //return ""; + os << " = " << op->attributes.at("value"); + } else { + std::string inputs = PrintArguments(op->args->inputs, "", ""); + std::vector attributes; + for (const auto& [key, value] : op->attributes) { + attributes.push_back(key + ": " + PrintAttribute(value)); + } + std::string attributes_str = PrintArray(attributes, "{", "}"); + + std::string shape_str = "";// PrintShape(ComputeShape(Value(op))); + + os << shape_str << (op->output_count > 1 ? "[" + std::to_string(op->output_count) + "]" : ""); + os << " = " << op->opcode << "(" << PrintArray({inputs, attributes_str}) << ")"; + } + return os.str(); +} + +std::string AddIndent(const std::string& str, int indent) { + // Add indentation to each line of the string + std::string indented; + std::istringstream iss(str); + std::string line; + while (std::getline(iss, line)) { + indented += std::string(indent, ' ') + line + "\n"; + } + return indented; +} + +std::string PrintBlock(OpBlock &root) { + return IterateWithLocalState(root, [](OpBlock::Iterator &it, std::string& current, const std::vector &kids) { + std::string result = PrintOp(*it); + if(result.empty()) return; + if (!kids.empty()) { + std::vector indented; + indented.reserve(kids.size()); + for (auto &s : kids) indented.push_back(AddIndent(s, 4)); + result += PrintArray(indented, " { \n", "}", "} else { \n"); + } + result += '\n'; + current += result; + }); +} + +void AssignVariableNames(OpBlock &block) { + ApplyOpTransform(block, [](Op &op) { + static size_t var_counter = 0; + size_t index = var_counter++; + op.varname = "var" + std::to_string(index); + op.index = index; + }); +} + +} diff --git a/TensorFrost/Compiler/src/TFProgram.cpp b/TensorFrost/Compiler/src/TFProgram.cpp new file mode 100644 index 00000000..988349a3 --- /dev/null +++ b/TensorFrost/Compiler/src/TFProgram.cpp @@ -0,0 +1,66 @@ +#include "Compiler/TFProgram.h" + +namespace TensorFrost { +TFProgram::TFProgram(std::function()> program_fn) { + StartExecutionContext(&context); + + auto [ins, outs] = program_fn(); + program_inputs = std::move(ins); + program_outputs = std::move(outs); + if (program_outputs.empty()) { + throw std::runtime_error("Program must have at least one output"); + } + EndExecutionContext(); +} + +void TFProgram::Compile() { + StartExecutionContext(&context); + ConstantFold(); + RemoveUnused(); + AssignVariableNames(*GetBaseBlock()); + EndExecutionContext(); +} + +void TFProgram::ConstantFold() { + ApplyOpTransform(*GetBaseBlock(), [](Op& op) { + OpSpec* spec = GetOpSpec(op.opcode); + if(!spec->const_fold) return; // Skip if no constant folding is defined for this operation + AttributeVector inputs; + for (const auto& arg : op.args->inputs) { + if(!arg->from->attributes.contains("value")) { + return; // Skip if some argument does not have a constant value + } + inputs.push_back(arg->from->attributes.at("value")); + } + Attribute result = spec->const_fold(inputs); + op.attributes["value"] = result; // Set the result as a constant value + op.opcode = "const"; // Change the opcode to constant + op.args->RemoveAll(); // Clear all arguments + }); +} + +void TFProgram::RemoveUnused() { + std::set used_ops = CollectDependencies(values_to_ops(program_outputs)); + IterateOver(*GetBaseBlock(), [&](OpBlock::Iterator& it) { + if (!used_ops.contains(*it)) { + it.remove(); // Remove unused operations + } + }); +} + +// Converts multilevel vmap operations into a sequence of vmaps with concatenated shape +void TFProgram::CombineVmapDepthwise() { + IterateOver(*GetBaseBlock(), [&](OpBlock::Iterator& it) { + static OpBlock* last_vmap_block = nullptr; + static OpBlock* current_vmap_block = nullptr; + static Shape current_shape; + }); +} + +std::string TFProgram::DebugPrint() const { + std::string program_header = "TFProgram(inputs=" + PrintArray(TransformVector(values_to_ops(program_inputs), VariableName), "[", "]") + ") {\n"; + std::string inner_code = PrintBlock(*context.base_block); + inner_code += "return " + PrintArray(TransformVector(values_to_ops(program_outputs), VariableName), "[", "]") + ";\n"; + return program_header + AddIndent(inner_code, 2) + "}\n"; +} +} diff --git a/TensorFrost/Compiler/src/Value.cpp b/TensorFrost/Compiler/src/Value.cpp new file mode 100644 index 00000000..52c17758 --- /dev/null +++ b/TensorFrost/Compiler/src/Value.cpp @@ -0,0 +1,174 @@ +#include "Compiler/Operation.h" +#include "Compiler/ExecutionContext.h" +#include "Compiler/Value.h" +using namespace std; + +namespace TensorFrost { + +Value::Value(Op* operation, int from_index) : op(operation), out_index(from_index) { + if (!op) { + throw std::runtime_error("Value cannot be constructed with a null Op pointer"); + } + if(from_index >= op->output_count) { + throw std::out_of_range("Output index out of range for the operation"); + } +} + +Value::Value(const Op *operation, int from_index) : out_index(from_index) { + op = const_cast(operation); + if (!op) { + throw std::runtime_error("Value cannot be constructed with a null Op pointer"); + } + if(from_index >= op->output_count) { + throw std::out_of_range("Output index out of range for the operation"); + } +} + +Value::Value(float value) : op(constant(value).op) {} +Value::Value(int value) : op(constant(value).op) {} +Value::Value(uint value) : op(constant(value).op) {} +Value::Value(bool value) : op(constant(value).op) {} +Value::Value(const Value &other): op(other.op), out_index(other.out_index) {} + +Value Value::operator+(const Value& other) const { + return value_op("add", {op, other.op}); +} +Value Value::operator-(const Value& other) const { + return value_op("sub", {op, other.op}); +} +Value Value::operator*(const Value& other) const { + return value_op("mul", {op, other.op}); +} +Value Value::operator/(const Value& other) const { + return value_op("div", {op, other.op}); +} +Value Value::operator%(const Value& other) const { + return value_op("mod", {op, other.op}); +} +Value Value::operator==(const Value& other) const { + return value_op("eq", {op, other.op}); +} +Value Value::operator!=(const Value& other) const { + return value_op("ne", {op, other.op}); +} +Value Value::operator<(const Value& other) const { + return value_op("lt", {op, other.op}); +} +Value Value::operator<=(const Value& other) const { + return value_op("le", {op, other.op}); +} +Value Value::operator>(const Value& other) const { + return value_op("gt", {op, other.op}); +} +Value Value::operator>=(const Value& other) const { + return value_op("ge", {op, other.op}); +} +Value Value::operator<<(const Value& other) const { + return value_op("shl", {op, other.op}); +} +Value Value::operator>>(const Value& other) const { + return value_op("shr", {op, other.op}); +} + +Value Value::operator&&(const Value& other) const { + return value_op("land", {op, other.op}); +} +Value Value::operator||(const Value& other) const { + return value_op("lor", {op, other.op}); +} +Value Value::operator!() const { + return value_op("lnot", {op}); +} + +Value Value::operator-() const { + return value_op("neg", {op}); +} +Value Value::operator+() const { + return value_op("pos", {op}); +} +Value Value::operator~() const { + return value_op("not", {op}); +} + +bool Value::Compare(const Value &other) const { + if(op == other.op) return true; + return op->Compare(*other.op); +} + +void Value::Set(Value value) { + Value set = value_op("set", {*this,value}); + this->op = set.op; + this->out_index = set.out_index; +} + +Value Value::operator[](const Values& indices) const { + return load_at_index(*this, indices); +} + +std::vector values_to_ops(const Values& values) { + std::vector ops; + ops.reserve(values.size()); + for (const auto& value : values) { + if (value.op) { + ops.push_back(value.op); + } else { + throw std::runtime_error("Value contains a null Op pointer"); + } + } + return ops; +} + +Values ops_to_values(const std::vector& ops) { + Values values; + values.reserve(ops.size()); + for (const auto& op : ops) { + if (op) { + values.emplace_back(op); + } else { + throw std::runtime_error("Op pointer in vector is null"); + } + } + return values; +} + +void Shape::AddDimension(const Value &dim) { + dimensions.push_back(dim); +} + +void Shape::AddDimensions(const Values &dims) { + dimensions.insert(dimensions.end(), dims.begin(), dims.end()); +} + +bool Shape::Broadcastable(const Shape &other) const { + size_t size = other.dimensions.size(); + if (dimensions.size() < size) { + return false; // Cannot broadcast if this shape has fewer dimensions + } + for (size_t i = 0; i < size; ++i) { + if (!dimensions[i].Compare(other.dimensions[i])) { + return false; + } + } + return true; +} + +Shape ComputeShape(Value x) { + Shape shape; + std::vector parents; + Op* current = x.op; + while(current) { + parents.push_back(current); + current = current->parent_block->parent_op; + } + std::reverse(parents.begin(), parents.end()); + for (const auto& parent : parents) { + OpSpec* spec = GetOpSpec(parent->opcode); + if(spec->props.contains(OpProp::HasShape)) { + shape.AddDimensions(parent->args->Inputs()); + } + } + return shape; +} + +} // namespace TensorFrost + diff --git a/TensorFrost/Frontend/Python/Definitions/PyModule.cpp b/TensorFrost/Frontend/Python/Definitions/PyModule.cpp deleted file mode 100644 index 219bda10..00000000 --- a/TensorFrost/Frontend/Python/Definitions/PyModule.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include -#include - -#include -#include - -namespace TensorFrost { - -class PyModule : public Module { -public: - using Module::Module; // Inherit constructors - - void assert_parameters() override { - PYBIND11_OVERRIDE(void, Module, assert_parameters); - } - - py::object loss(py::object X, py::object Y) override { - PYBIND11_OVERRIDE_PURE(py::object, Module, loss, X, Y); - } - - py::object forward(py::object X) override { - PYBIND11_OVERRIDE_PURE(py::object, Module, forward, X); - } -}; - -void ModuleDefinitions(py::module& m) { - py::class_(m, "Parameter") - .def(py::init&, TFDataFormat, float, float, bool>(), py::arg("shape"), py::arg("dtype") = TFType::Float, py::arg("random_scale") = -1.0f, py::arg("random_offset") = 0.0f, py::arg("optimize") = true) - .def_readwrite("shape", &Parameter::shape) - .def_readwrite("dtype", &Parameter::dtype) - .def_readwrite("random_scale", &Parameter::random_scale) - .def_readwrite("random_offset", &Parameter::random_offset) - .def("__repr__", [](const Parameter& p) { - return "Parameter(shape=" + std::to_string(p.shape.size()) + ", dtype=" + std::to_string(p.dtype.type) + "( " + std::to_string(p.dtype.size) + ") , random_scale=" + std::to_string(p.random_scale) + ", random_offset=" + std::to_string(p.random_offset) + ", optimize=" + std::to_string(p.optimize) + ")"; - }); - - py::class_(m, "ParameterArray") - .def(py::init<>()) - .def("__getitem__", &ParameterArray::getitem) - .def("__setitem__", &ParameterArray::setitem) - .def("items", &ParameterArray::items); - - py::class_(m, "Module") - .def(py::init(), py::arg("requires_grad") = true) - .def("__getattr__", &Module::getattr) - .def("__setattr__", &Module::setattr) - .def("hasattr", &Module::hasattr) - .def("param_requires_grad", &Module::param_requires_grad) - .def("initialize_input", &Module::initialize_input) - .def("initialize_parameters", &Module::initialize_parameters) - .def("initialize_parameters_native", &Module::initialize_parameters_native) - .def("parameters", &Module::parameters) - .def("named_parameters", &Module::named_parameters) - .def("requires_grads_list", &Module::requires_grads_list) - .def("create_input", &Module::create_input) - .def("update_parameters", &Module::update_parameters) - .def("assert_parameters", &Module::assert_parameters) - .def("loss", &Module::loss) - .def("forward", &Module::forward); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/PyTensor.cpp b/TensorFrost/Frontend/Python/Definitions/PyTensor.cpp deleted file mode 100644 index 750b8c02..00000000 --- a/TensorFrost/Frontend/Python/Definitions/PyTensor.cpp +++ /dev/null @@ -1,199 +0,0 @@ -#include -#include - -#include - -namespace TensorFrost { - -void DefineOperator( - const std::string& pyname, py::class_& py_tensor, - const std::function& op) { - py_tensor.def(l_op(pyname).c_str(), - [op](const PyTensor& t, const PyTensor& t2) { - return PT(op(T(t), T(t2))); - }); - py_tensor.def(l_op(pyname).c_str(), [op](const PyTensor& t, const float f) { - return PT(op(T(t), Tensor::Constant(f))); - }); - py_tensor.def(l_op(pyname).c_str(), [op](const PyTensor& t, const int i) { - return PT(op(T(t), Tensor::Constant(i))); - }); - py_tensor.def(r_op(pyname).c_str(), [op](const PyTensor& t, const float f) { - return PT(op(Tensor::Constant(f), T(t))); - }); - py_tensor.def(r_op(pyname).c_str(), [op](const PyTensor& t, const int i) { - return PT(op(Tensor::Constant(i), T(t))); - }); -} - -#define LAMBDA_OP(op) \ - [](const Tensor& t1, const Tensor& t2) -> Tensor& { return t1 op t2; } - -void DefineOperators(py::class_& py_tensor) { - DefineOperator("add", py_tensor, LAMBDA_OP(+)); - DefineOperator("sub", py_tensor, LAMBDA_OP(-)); - DefineOperator("mul", py_tensor, LAMBDA_OP(*)); - DefineOperator("div", py_tensor, LAMBDA_OP(/)); - DefineOperator("truediv", py_tensor, LAMBDA_OP(/)); - DefineOperator("mod", py_tensor, LAMBDA_OP(%)); - DefineOperator("eq", py_tensor, LAMBDA_OP(==)); - DefineOperator("ne", py_tensor, LAMBDA_OP(!=)); - DefineOperator("lt", py_tensor, LAMBDA_OP(<)); - DefineOperator("le", py_tensor, LAMBDA_OP(<=)); - DefineOperator("gt", py_tensor, LAMBDA_OP(>)); - DefineOperator("ge", py_tensor, LAMBDA_OP(>=)); - DefineOperator("and", py_tensor, LAMBDA_OP(&&)); - DefineOperator("or", py_tensor, LAMBDA_OP(||)); - DefineOperator("xor", py_tensor, LAMBDA_OP(^)); - DefineOperator("lshift", py_tensor, LAMBDA_OP(<<)); - DefineOperator("rshift", py_tensor, LAMBDA_OP(>>)); - DefineOperator("and_", py_tensor, LAMBDA_OP(&)); - DefineOperator("or_", py_tensor, LAMBDA_OP(|)); - - py_tensor.def("__neg__", [](const PyTensor& t) { return PT(-T(t)); }); - py_tensor.def("__not__", [](const PyTensor& t) { return PT(!T(t)); }); - py_tensor.def("__invert__", [](const PyTensor& t) { return PT(~T(t)); }); - py_tensor.def("__pow__", [](const PyTensor& t, const PyTensor& t2) { - return PT(Tensor::pow(T(t), T(t2))); - }); - py_tensor.def("__pow__", [](const PyTensor& t, float f) { - return PT(Tensor::pow(T(t), Tensor::Constant(f))); - }); - py_tensor.def("__rpow__", [](const PyTensor& t, float f) { - return PT(Tensor::pow(Tensor::Constant(f), T(t))); - }); - py_tensor.def("__matmul__", [](const PyTensor& t, const PyTensor& t2) { - return PT(Tensor::Matmul(T(t), T(t2))); - }); -} - -void PyTensorDefinition(py::module& /*m*/, py::class_& py_tensor) { - // initializers - py_tensor.def(py::init()); - py_tensor.def(py::init()); - py_tensor.def(py::init()); - py_tensor.def(py::init()); - - // properties - py_tensor.def_property_readonly("shape", [](const PyTensor& t) { - return PyTensorsFromTensors(Reverse(t.Get().GetShape())); - }); - py_tensor.def_property_readonly( - "type", [](const PyTensor& t) { return t.Get().GetFormat(); }); - py_tensor.def_property_readonly("indices", [](const PyTensor& t) { - int dim = T(t).GetDimension(); - py::tuple indices(dim); - for (int i = 0; i < dim; i++) { - indices[i] = PT(T(t).Index(dim - i - 1)); - } - return indices; - }); - py_tensor.def_property_readonly("op_name", [](const PyTensor& t) { - return T(t).node_->name; - }); - py_tensor.def("try_get_constant", [](const PyTensor& t) { - if(T(t).node_->name != "const") { - throw std::runtime_error("Can not get constant from non-constant tensor"); - } - return T(t).TryGetConstant(); - }); - py_tensor.def("index",[](const PyTensor& t, int dim) { - int dims = T(t).GetDimension(); - return PT(T(t).Index(dims - dim - 1)); - }); - - py_tensor.def("block_index", [](const PyTensor& t) { - return PT(T(t).BlockIndex()); - }); - - py_tensor.def("block_thread_index", [](const PyTensor& t, int block_dim) { - return PT(T(t).BlockThreadIndex(block_dim)); - }); - - py_tensor.def("detach_grad", [](const PyTensor& t) { - t.Get().DetachGrad(); - return t; - }); - - py_tensor.def("pass_grad", [](const PyTensor& t) { - t.Get().PassGrad(); - return t; - }); - - py_tensor.def("stop_fusion", [](const PyTensor& t) { - t.Get().StopFusion(); - return t; - }); - - py_tensor.def("hint_range", [](const PyTensor& t, py::object min, py::object max) { - if(t.Get().node_->format == TFTypeFloat32) { - t.Get().HintRange(py::cast(min), py::cast(max)); - } else { - t.Get().HintRange(py::cast(min), py::cast(max)); - } - }, py::arg("min"), py::arg("max")); - - // operators - DefineOperators(py_tensor); - - //no way to overload normal setter - //TODO use python AST to generate these functions - py_tensor.def("set", - [](const PyTensor& t, const PyTensor& t2) { T(t).Set(T(t2)); }); - - py_tensor.def_property("val", [](const PyTensor& t) { return t; }, - [](PyTensor& t, const PyTensor& val) { T(t).Set(T(val)); }); - - // indexing - py_tensor.def("__getitem__", [](const PyTensor& t, const PyTensor& t1) { - Tensors indices; - indices.push_back(&t1.Get()); - return PyTensor(&t.Get(), indices); - }); - py_tensor.def("__getitem__", [](const PyTensor& t, py::tuple indices_tuple) { - Tensors indices = Reverse(TensorsFromTuple(indices_tuple)); - return PyTensor(&t.Get(), indices); - }); - - py_tensor.def("__setitem__", - [](const PyTensor& t, const PyTensor& t1, const PyTensor& t2) { - Tensors indices; - indices.push_back(&t1.Get()); - Tensor::Store(t.Get(), T(t2), indices); - }); - py_tensor.def("__setitem__", [](const PyTensor& t, py::tuple indices_tuple, - const PyTensor& t2) { - Tensors indices = Reverse(TensorsFromTuple(indices_tuple)); - Tensor::Store(t.Get(), T(t2), indices); - }); - - py_tensor.def("__setitem__", [](const PyTensor& t, const PyTensor& t1, pybind11::none none) { - //do nothing - }); - py_tensor.def("__setitem__", [](const PyTensor& t, py::tuple indices_tuple, pybind11::none none) { - //do nothing - }); - - // transpose - py_tensor.def("transpose", [](const PyTensor& t, int dim1, int dim2) { - return PT(Tensor::Transpose(T(t), -dim1-1, -dim2-1)); - }, py::arg("dim1") = -2, py::arg("dim2") = -1, "Transpose the tensor"); - - //transpose property - py_tensor.def_property_readonly("T", [](const PyTensor& t) { - return PT(Tensor::Transpose(T(t))); - }); - - py_tensor.def("__str__", [](const PyTensor& t) { - return GetNodeString(t.Get().node_); - }); - py_tensor.def("__repr__", [](const PyTensor& t) { - return GetNodeString(t.Get().node_); - }); - - py_tensor.def("set_debug_name", [](const PyTensor& t, const std::string& name) { - t.Get().SetDebugName(name); - }); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/TensorFunctions.cpp b/TensorFrost/Frontend/Python/Definitions/TensorFunctions.cpp deleted file mode 100644 index 70637f84..00000000 --- a/TensorFrost/Frontend/Python/Definitions/TensorFunctions.cpp +++ /dev/null @@ -1,353 +0,0 @@ -#include -#include - -#include - -namespace TensorFrost { - -#define UNARY_FUNCTION(name) \ - m.def(#name, [](const PyTensor& t) { return PT(Tensor::name(T(t))); }) - -#define BINARY_FUNCTION(name) \ - m.def(#name, [](const PyTensor& t, const PyTensor& t2) { \ - return PT(Tensor::name(T(t), T(t2))); \ - }) - -#define TERNARY_FUNCTION(name) \ - m.def(#name, [](const PyTensor& t, const PyTensor& t2, const PyTensor& t3) { \ - return PT(Tensor::name(T(t), T(t2), T(t3))); \ - }) - -void TensorFunctionsDefinition(py::module& m) { - UNARY_FUNCTION(copy); - UNARY_FUNCTION(abs); - UNARY_FUNCTION(ceil); - UNARY_FUNCTION(floor); - UNARY_FUNCTION(round); - UNARY_FUNCTION(trunc); - UNARY_FUNCTION(sign); - UNARY_FUNCTION(frac); - UNARY_FUNCTION(sin); - UNARY_FUNCTION(cos); - UNARY_FUNCTION(tan); - UNARY_FUNCTION(asin); - UNARY_FUNCTION(acos); - UNARY_FUNCTION(atan); - UNARY_FUNCTION(sinh); - UNARY_FUNCTION(cosh); - UNARY_FUNCTION(tanh); - UNARY_FUNCTION(exp); - UNARY_FUNCTION(exp2); - UNARY_FUNCTION(log); - UNARY_FUNCTION(log2); - UNARY_FUNCTION(sqrt); - UNARY_FUNCTION(sqr); - UNARY_FUNCTION(rsqrt); - UNARY_FUNCTION(rcp); - - UNARY_FUNCTION(pcg); - UNARY_FUNCTION(pcgf); - UNARY_FUNCTION(reversebits); - - m.def("float", [](const PyTensor& t) { return PT(Tensor::tofloat(T(t))); }); - m.def("uint", [](const PyTensor& t) { return PT(Tensor::touint(T(t))); }); - m.def("int", [](const PyTensor& t) { return PT(Tensor::toint(T(t))); }); - m.def("bool", [](const PyTensor& t) { return PT(Tensor::tobool(T(t))); }); - - m.def("asfloat", [](const PyTensor& t) { return PT(Tensor::asfloat(T(t))); }); - m.def("asuint", [](const PyTensor& t) { return PT(Tensor::asuint(T(t))); }); - m.def("asint", [](const PyTensor& t) { return PT(Tensor::asint(T(t))); }); - - BINARY_FUNCTION(min); - BINARY_FUNCTION(max); - BINARY_FUNCTION(pow); - BINARY_FUNCTION(atan2); - BINARY_FUNCTION(modf); - - BINARY_FUNCTION(grad); - - TERNARY_FUNCTION(clamp); - TERNARY_FUNCTION(fma); - TERNARY_FUNCTION(lerp); - TERNARY_FUNCTION(select); - TERNARY_FUNCTION(smoothstep); - - m.def("scatterAdd", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterAdd(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterAddPrev", [](const PyTensor& t, const PyTensor& t2) { - return PT(Tensor::ScatterAddPrev(*t.Value(), T(t2), t.Indices())); - }); - - m.def("scatterMin", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterMin(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterMax", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterMax(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterOr", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterOr(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterAnd", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterAnd(*t.Value(), T(t2), t.Indices()); - }); - - m.def("scatterXor", [](const PyTensor& t, const PyTensor& t2) { - Tensor::ScatterXor(*t.Value(), T(t2), t.Indices()); - }); - - m.def("buffer", [](py::list shape, TFDataFormat type) { - return PT(Tensor::Memory(Reverse(TensorsFromList(shape)), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - m.def("buffer", [](std::vector shape, TFDataFormat type) { - return PT(Tensor::Memory(Reverse(shape), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - - m.def("local_buffer", [](int size, TFDataFormat type) { - return PT(Tensor::LocalMemory(size, type)); - }, py::arg("size"), py::arg("type") = TFTypeFloat32); - m.def("group_buffer", [](int size, TFDataFormat type) { - return PT(Tensor::GroupMemory(size, type)); - }, py::arg("size"), py::arg("type") = TFTypeFloat32); - m.def("group_barrier", []() { - Tensor::GroupBarrier(); - }); - - m.def("zeros", [](py::list shape, TFDataFormat type) { - return PT(Tensor::Constant(0u, Reverse(TensorsFromList(shape)), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - - m.def("const", [](float value, py::list shape) { - return PT(Tensor::Constant(Reverse(TensorsFromList(shape)), value)); - }); - m.def("const", [](float value, std::vector shape) { - return PT(Tensor::Constant(Reverse(shape), value)); - }, py::arg("value"), py::arg("shape") = std::vector{}); - - m.def("const", [](int value, py::list shape) { - return PT(Tensor::Constant(Reverse(TensorsFromList(shape)), value)); - }); - - m.def("const", [](int value, std::vector shape) { - return PT(Tensor::Constant(Reverse(shape), value)); - }, py::arg("value"), py::arg("shape") = std::vector{}); - - m.def("input", [](std::vector shape, TFDataFormat type) { - return PT(Tensor::Input(Reverse(shape), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - - m.def("input", [](py::list shape, TFDataFormat type) { - return PT(Tensor::Input(Reverse(TensorsFromList(shape)), type)); - }, py::arg("shape"), py::arg("type") = TFTypeFloat32); - - m.def("index", [](int dim, py::list shape) { - return PT(Tensor::Index(Reverse(TensorsFromList(shape)), dim)); - }); - - m.def("hash", [](py::list shape, const PyTensor& seed) { - return PT(Tensor::Hash(Reverse(TensorsFromList(shape)), T(seed))); - }, py::arg("shape"), py::arg("seed")); - - m.def("random_value", [](py::list shape, const PyTensor& seed) { - return PT(Tensor::Random(Reverse(TensorsFromList(shape)), T(seed))); - }, py::arg("shape"), py::arg("seed")); - - m.def("element_index", [](py::list shape) { - return PT(Tensor::ElementIndex(Reverse(TensorsFromList(shape)))); - }, py::arg("shape")); - - m.def("flat_index", [](py::list shape, py::list indices) { - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - Tensors index_tensors = Reverse(TensorsFromList(indices)); - return PT(Tensor::FlatIndex(shape_tensors, index_tensors)); - }); - - m.def("indices_from_flat_index", [](const PyTensor& index, py::list shape) { - py::tuple indices = py::tuple(shape.size()); - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - Tensors indices_tensors = Reverse(Tensor::IndicesFromFlatIndex(&T(index), shape_tensors)); - for (int i = 0; i < indices_tensors.size(); i++) { - indices[i] = PT(*indices_tensors[i]); - } - return indices; - }); - - m.def("get_copy", [](const PyTensor& t) { return PT(*Tensor::GetCopy(T(t))); }); - - m.def("indices", [](py::list shape) { - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - int dim = (int)shape_tensors.size(); - py::tuple indices = py::tuple(shape_tensors.size()); - for (int i = 0; i < shape_tensors.size(); i++) { - auto t = PT(Tensor::Index(shape_tensors, dim - i - 1)); - indices[i] = t; - } - return indices; - }); - - m.def("indices", [](std::vector shape) { - py::tuple indices = py::tuple(shape.size()); - int dim = (int)shape.size(); - for (int i = 0; i < shape.size(); i++) { - auto t = PT(Tensor::Index(Reverse(shape), dim - i - 1)); - indices[i] = t; - } - return indices; - }); - - - m.def("index_grid", [](py::list begin, py::list end) { - Tensors begin_tensors = Reverse(TensorsFromList(begin)); - Tensors end_tensors = Reverse(TensorsFromList(end)); - Tensors index_grid = Reverse(Tensor::IndexGrid(begin_tensors, end_tensors)); - - py::tuple indices = py::tuple(begin.size()); - for (int i = 0; i < index_grid.size(); i++) { - indices[i] = PT(*index_grid[i]); - } - return indices; - }); - - m.def("index_grid", [](py::list begin, py::list end, py::list step) { - Tensors begin_tensors = Reverse(TensorsFromList(begin)); - Tensors end_tensors = Reverse(TensorsFromList(end)); - Tensors step_tensors = Reverse(TensorsFromList(step)); - Tensors index_grid = Reverse(Tensor::IndexGrid(begin_tensors, end_tensors, step_tensors)); - - py::tuple indices = py::tuple(begin.size()); - for (int i = 0; i < index_grid.size(); i++) { - indices[i] = PT(*index_grid[i]); - } - return indices; - }); - - m.def("reshape", [](const PyTensor& t, py::list shape, TFDataFormat type) { - return PT(Tensor::Reshape(T(t), Reverse(TensorsFromList(shape)), type)); - }, py::arg("t"), py::arg("shape"), py::arg("type") = TFTypeNone); - - m.def("assert_tensor", [](const PyTensor& t, py::list target_shape, TFDataFormat target_type) { - return PT(Tensor::Assert(T(t), Reverse(TensorsFromList(target_shape)), target_type)); - }); - m.def("split_dim", [](const PyTensor& t, const int split_size, const int axis) { - return PT(Tensor::SplitDim(T(t), split_size, -axis-1)); - }, py::arg("t"), py::arg("split_size"), py::arg("axis") = -1); - m.def("merge_dim", [](const PyTensor& t, const int axis, const PyTensor* target_size) { - const Tensor* target_size_ptr = target_size ? &T(*target_size) : nullptr; - return PT(Tensor::MergeDim(T(t), -axis-1, target_size_ptr)); - }, py::arg("t"), py::arg("axis") = -1, py::arg("target_size") = nullptr); - m.def("repeat", [](const PyTensor& t, const PyTensor& repeats, const int axis) { - return PT(Tensor::Repeat(T(t), T(repeats), -axis-1)); - }, py::arg("t"), py::arg("repeats"), py::arg("axis") = -1); - - //algorithm functions - m.def("sum", [](const PyTensor& t, const int axis) { return PT(Tensor::Sum(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Sum the elements of the tensor along the axis"); - - m.def("norm", [](const PyTensor& t, const int axis) { return PT(Tensor::Norm(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the norm of the tensor along the axis"); - - m.def("mean", [](const PyTensor& t, const int axis) { return PT(Tensor::Mean(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the mean of the tensor along the axis"); - - m.def("min", [](const PyTensor& t, const int axis) { return PT(Tensor::Min(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the min of the tensor along the axis"); - - m.def("max", [](const PyTensor& t, const int axis) { return PT(Tensor::Max(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the max of the tensor along the axis"); - - m.def("any", [](const PyTensor& t, const int axis) { return PT(Tensor::Any(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Do an OR operation along the axis"); - - m.def("all", [](const PyTensor& t, const int axis) { return PT(Tensor::All(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Do an AND operation along the axis"); - - m.def("prefix_sum", [](const PyTensor& t, const int axis) { return PT(Tensor::PrefixSum(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the prefix sum of the tensor along the axis"); - - m.def("reverse", [](const PyTensor& t, const int axis) { return PT(Tensor::Reverse(T(t), -axis-1)); }, - py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Reverse the tensor along the axis"); - - m.def("transpose", [](const PyTensor& t, int dim1, int dim2) { - return PT(Tensor::Transpose(T(t), -dim1-1, -dim2-1)); - }, py::arg("t"), py::kw_only(), py::arg("dim1") = -2, py::arg("dim2") = -1, "Transpose the tensor"); - - m.def("unsqueeze", [](const PyTensor& t, int dim) { - return PT(Tensor::Unsqueeze(T(t), -dim-1)); - }, py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Unsqueeze the tensor"); - - m.def("squeeze", [](const PyTensor& t, int dim) { - return PT(Tensor::Squeeze(T(t), -dim-1)); - }, py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Squeeze the tensor"); - - m.def("dot", [](const PyTensor& t, const PyTensor& t2, int axis) { - return PT(Tensor::Dot(T(t), T(t2), -axis-1)); - }, py::arg("t"), py::arg("t2"), py::kw_only(), py::arg("axis") = -1, "Dot product of two tensors"); - - m.def("matmul", [](const PyTensor& t, const PyTensor& t2) { - return PT(Tensor::Matmul(T(t), T(t2))); - }, py::arg("t"), py::arg("t2"), "Matrix multiplication of two tensors"); - - m.def("region_begin", [](const std::string& name) { - Tensor::BeginRegion(name); - }, py::arg("name"), "Begin a debug region"); - - m.def("region_end", [](const std::string& name) { - Tensor::EndRegion(name); - }, py::arg("name"), "End a debug region"); - - m.def("register_custom_operation", [](const std::string& name, vector overloads, py::function impl, py::function vjp) { - auto cpp_impl = [impl](Tensors& output, map inputs, const Tensor* tensor, vector axes) { - py::list input_list; - for (auto& [id, tensor] : inputs) { - input_list.append(PT(*tensor)); - } - py::list output_list = impl(input_list, PT(*tensor), axes).cast(); - for (int i = 0; i < output_list.size(); i++) { - PyTensor* t = output_list[i].cast(); - output.push_back(&t->Get()); - } - }; - - auto cpp_vjp = [vjp](map inputs, const Tensor* gradient, const Tensor* tensor) { - py::list input_list; - for (auto& [id, tensor] : inputs) { - input_list.append(PT(*tensor)); - } - py::list output_list = vjp(input_list, PT(*gradient), PT(*tensor)).cast(); - Tensors gradients; - for (int i = 0; i < output_list.size(); i++) { - PyTensor* t = output_list[i].cast(); - gradients.push_back(&t->Get()); - } - return gradients; - }; - - RegisterAlgorithmicPrimitive(name, overloads, cpp_impl, cpp_vjp); - }, py::arg("name"), py::arg("overloads"), py::arg("impl"), py::arg("vjp"), "Register a custom operation"); - - m.def("custom", [](const std::string& name, py::list inputs, py::list shape) { - Tensors input_tensors = TensorsFromList(inputs); - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - return PT(Tensor::CustomOperation(name, input_tensors, shape_tensors)); - }, py::arg("name"), py::arg("inputs"), py::arg("shape"), "Run custom operation"); - - m.def("custom", [](const std::string& name, py::list inputs) { - Tensors input_tensors = TensorsFromList(inputs); - Tensors shape_tensors = input_tensors[0]->GetShape(); - return PT(Tensor::CustomOperation(name, input_tensors, shape_tensors)); - }, py::arg("name"), py::arg("inputs"), "Run custom operation"); - - m.def("print_value", [](const std::string& name, const PyTensor& t) { - Tensor::PrintValue(name, T(t)); - }, py::arg("name"), py::arg("t"), "Print the value of the tensor"); - - m.def("assert_value", [](const std::string& name, const PyTensor& t) { - Tensor::AssertValue(name, T(t)); - }, py::arg("name"), py::arg("t"), "Assert the value of the tensor"); -} - -} // namespace TensorFrost diff --git a/TensorFrost/Frontend/Python/Definitions/TensorMemory.cpp b/TensorFrost/Frontend/Python/Definitions/TensorMemory.cpp deleted file mode 100644 index d3b96169..00000000 --- a/TensorFrost/Frontend/Python/Definitions/TensorMemory.cpp +++ /dev/null @@ -1,72 +0,0 @@ -#include -#include - -#include -#include - -namespace TensorFrost { - -void TensorMemoryDefinition(py::module& m, - py::class_& py_tensor_mem) { - //define constructors from numpy arrays - py_tensor_mem.def(py::init([](py::array arr) { - return PyTensorMemory(arr); - }), "Create a TensorMemory from a numpy array", py::return_value_policy::take_ownership); - - // "constructor" - m.def( - "tensor", - [](const std::vector& shape, TFDataFormat type) { - return PyTensorMemory(shape, type); - },"Create a TensorMemory with the given shape", py::return_value_policy::take_ownership); - - // "constructor" from numpy array - m.def( - "tensor", - [](py::array arr) { - return new PyTensorMemory(arr); - }, - "Create a TensorMemory from a numpy array", py::return_value_policy::take_ownership); - - // properties - py_tensor_mem.def_property_readonly("shape", [](const PyTensorMemory& t) { - vector shape = t.Shape(); - return py::cast(shape); - }); - - py_tensor_mem.def_property_readonly("type", [](const PyTensorMemory& t) { - return t.GetFormat(); - }); - - py_tensor_mem.def_property_readonly("size", [](const PyTensorMemory& t) { - return GetSize(t.tensor_); - }); - - // to numpy array - py_tensor_mem.def_property_readonly( - "numpy", - [](const PyTensorMemory& t) - -> std::variant, py::array_t, - py::array_t, py::array_t> { - if (t.GetFormat() == TFTypeFloat32) { - return t.ToPyArray(); - } else if (t.GetFormat() == TFTypeInt32) { - return t.ToPyArray(); - } else if (t.GetFormat() == TFTypeUint32) { - return t.ToPyArray(); - } else if (t.GetFormat() == TFTypeBool32) { - return t.ToPyArray(); - } else { - throw std::runtime_error("Unsupported data type for numpy conversion"); - } - }, - "Readback data from tensor memory to a numpy array", py::return_value_policy::take_ownership); - - m.def("allocated_memory", []() { return global_memory_manager->GetAllocatedSize(); }, - "Get the amount of memory currently used by the memory manager"); - - m.def("unused_allocated_memory", []() { return global_memory_manager->GetUnusedAllocatedSize(); }, - "Get the amount of memory currently allocated but not used by the memory manager"); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/TensorProgram.cpp b/TensorFrost/Frontend/Python/Definitions/TensorProgram.cpp deleted file mode 100644 index f84cdc21..00000000 --- a/TensorFrost/Frontend/Python/Definitions/TensorProgram.cpp +++ /dev/null @@ -1,199 +0,0 @@ -#include -#include - -#include -#include -#include - -namespace TensorFrost { - -void TensorProgramDefinition(py::module& m, - py::class_& tensor_program) { - m.def( - "compile", - [](const py::function& py_evaluate) { - // Extract the name of the Python function - std::string func_name = - py_evaluate.attr("__name__").cast(); - - vector inputs = GetFunctionArguments(py_evaluate); - vector arg_names; - vector arg_props; - for (auto arg : inputs) { - arg_names.push_back(std::get<0>(arg)); - py::object arg_prop = std::get<1>(arg); - py::object arg_default = std::get<2>(arg); - if (py::isinstance(arg_prop)) { - PyTensorArg arg_tensor = arg_prop.cast(); - arg_props.push_back(arg_tensor); - } else { - throw std::runtime_error("Unsupported input type " + std::string(py::str(arg_prop))); - } - } - - TensorProgram& program = *new TensorProgram( - [py_evaluate, arg_names, arg_props]() -> Tensors { - py::gil_scoped_acquire acquire; - std::vector args; - //create inputs from the arguments - for (size_t i = 0; i < arg_names.size(); i++) { - Tensor& input = Tensor::Input(arg_props[i].shape, arg_props[i].type); - input.SetDebugName(arg_names[i]); - PyTensor* py_tensor = new PyTensor(&input); - args.push_back(py_tensor); - } - //convert to py::args - py::args py_args = py::cast(args); - py::object result = py_evaluate(*py_args); - Tensors outputs; - //if the result is a single tensor - if (py::isinstance(result)) { - outputs.push_back(&py::cast(result).Get()); - } else { - auto py_outputs = py::cast>(result); - for (PyTensor output : py_outputs) { - outputs.push_back(&output.Get()); - } - } - return outputs; - }, - func_name); - - py::print(program.PrintProperties()); - return &program; - }, - "Compile a TensorProgram from a python function"); - - tensor_program.def( - "__call__", - [](TensorProgram& program, py::args py_inputs) -> std::variant { - vector inputs_props; - vector temp_numpy_tensors; - for (auto arg : py_inputs) { - if (py::isinstance(arg)) { //if just tensor memory - PyTensorMemory* mem = &arg.cast(); - inputs_props.push_back(arg.cast()); - } else if (py::isinstance(arg)) { //if module then add its parameters - Module* module = &arg.cast(); - py::list params = module->parameters(); - for (auto param : params) { - inputs_props.push_back(param.cast()); - } - } else if (py::isinstance(arg)) { //if numpy array then create pytensormemory from it and add it - py::array arr = arg.cast(); - PyTensorMemory* temp_tensor = new PyTensorMemory(arr); - inputs_props.push_back(py::cast(temp_tensor, py::return_value_policy::take_ownership)); - temp_numpy_tensors.push_back(temp_tensor->tensor_); - } else if (py::isinstance(arg)) { //if list then convert to py::array then create pytensormemory from it and add it - py::array arr = ListToArray(arg.cast()); - PyTensorMemory* temp_tensor = new PyTensorMemory(arr); - inputs_props.push_back(py::cast(temp_tensor, py::return_value_policy::take_ownership)); - temp_numpy_tensors.push_back(temp_tensor->tensor_); - } else { - throw std::runtime_error("Unsupported input type " + std::string(py::str(arg))); - } - } - - vector inputs; - for (auto input : inputs_props) { - PyTensorMemory* mem = input.cast(); - inputs.push_back(mem->tensor_); - } - vector outputs = program.Evaluate(inputs); - - //remove temporary tensors if they are not in the outputs - for (TFTensor* temp_tensor : temp_numpy_tensors) { - bool found = false; - for (TFTensor* output : outputs) { - if (temp_tensor->buffer == output->buffer) { - found = true; - break; - } - } - if (!found) { - global_memory_manager->DeallocateTensor(*temp_tensor); - } - } - - vector output_tensors; - for (size_t i = 0; i < outputs.size(); i++) { - //if any of the outputs are also inputs, then replace them with the input tensors - TFTensor* out = outputs[i]; - bool is_input = false; - for (size_t j = 0; j < inputs_props.size(); j++) { - PyTensorMemory* in = inputs_props[j].cast(); - if (out->buffer == in->tensor_->buffer) { - output_tensors.push_back(inputs_props[j]); - is_input = true; - break; - } - } - if (is_input) { - continue; - } - //otherwise create a new tensor memory - output_tensors.push_back(py::cast(new PyTensorMemory(outputs[i]), py::return_value_policy::take_ownership)); - } - - //if there is only one output, return the tensor memory - if (outputs.size() == 1) { - return output_tensors[0]; - } else { - //convert to py::tuple of PyTensorMemory* - py::tuple py_outputs = py::tuple(outputs.size()); - for (size_t i = 0; i < outputs.size(); i++) { - py_outputs[i] = output_tensors[i]; - } - return py_outputs; - } - }, - "Evaluate the TensorProgram with the given inputs"); - - tensor_program.def( - "list_operations", - [](TensorProgram& program, bool compact) { - std::string listing = "List of operations:\n"; - listing += GetOperationListing(program.ir, compact); - return py::str(listing); - }, - py::arg("compact") = true); - - tensor_program.def("compiled_code", [](TensorProgram& program) { - string code = program.program->generated_code_; - return py::str(code); - }); - - tensor_program.def("get_kernels", [](TensorProgram& program) { - vector kernel_source; - for (auto& kernel : program.program->kernels_) { - kernel_source.push_back(kernel.full_generated_code_); - } - return kernel_source; - }); - - tensor_program.def("get_main_function", [](TensorProgram& program) { - return program.program->main_function_; - }); - - tensor_program.def("get_last_execution_time", [](TensorProgram& program) { - return program.program->last_execution_time; - }); - - m.def("get_all_generated_main_functions", []() { - return global_kernel_manager->GetAllMainFunctions(); - }); - - m.def("get_all_generated_kernels", []() { - return global_kernel_manager->GetAllKernels(); - }); - - m.def("get_cpp_header", []() { - return GetCPPHeader(); - }); - - m.def("get_cpp_implementation", []() { - return GetCPPImplementation(); - }); -} - -} // namespace TensorFrost diff --git a/TensorFrost/Frontend/Python/Definitions/TensorScope.cpp b/TensorFrost/Frontend/Python/Definitions/TensorScope.cpp deleted file mode 100644 index 886bb5a4..00000000 --- a/TensorFrost/Frontend/Python/Definitions/TensorScope.cpp +++ /dev/null @@ -1,118 +0,0 @@ -#include -#include - -#include - -namespace TensorFrost { - -void ScopeDefinitions(py::module& m, py::class_& py_tensor) { - m.def( - "loop", - [](const py::function& body, const PyTensor& begin, const PyTensor& end, - const PyTensor& step) { - // wrap the function to convert the PyTensor to Tensor - std::function f2 = [&body](const Tensor& t) { - py::gil_scoped_acquire acquire; - body(PT(t)); - }; - - Tensor::Loop(T(begin), T(end), T(step), f2); - }, - py::arg("begin") = 0, py::arg("end"), py::arg("step") = 1, - py::arg("body")); - - m.def( - "if_cond", - [](const PyTensor& condition, const py::function& true_body) { - std::function f = [&true_body]() { - py::gil_scoped_acquire acquire; - true_body(); - }; - Tensor::If(T(condition), f); - }, - py::arg("condition"), py::arg("true_body")); - - m.def( - "if_cond", - [](const PyTensor& condition, const py::function& true_body, - const py::function& false_body) { - std::function f1 = [&true_body]() { - py::gil_scoped_acquire acquire; - true_body(); - }; - std::function f2 = [&false_body]() { - py::gil_scoped_acquire acquire; - false_body(); - }; - Tensor::If(T(condition), f1, f2); - }, - py::arg("condition"), py::arg("true_body"), py::arg("false_body")); - - m.def("break_loop", []() { Tensor::Break(); }); - m.def("continue_loop", []() { Tensor::Continue(); }); - - m.def( - "kernel", - [](py::list shape, const py::function& body) { - // wrap the function to convert the PyTensor to Tensor - std::function f2 = - [&body](const Tensors& tensors) { - py::gil_scoped_acquire acquire; - PyTensors py_tensors = PyTensorsFromTensors(tensors); - body(py_tensors); - }; - - Tensor::Kernel(Reverse(TensorsFromList(shape)), f2); - }, - py::arg("shape"), py::arg("body")); - - // m.def( - // "vmap", - // [](py::list inputs, py::list shape, const py::function& func) { - // std::function f = [&func]() { - // py::gil_scoped_acquire acquire; - // func(); - // }; - // Tensor::If(T(condition), f); - // }, - // py::arg("condition"), py::arg("true_body")); - - py_tensor.def("__enter__", &PyTensor::__enter__); - py_tensor.def("__exit__", &PyTensor::__exit__); - - //loop scope - m.def("loop", - [](const PyTensor& begin, const PyTensor& end, const PyTensor& step) { - Tensor& for_loop = Tensor::Loop(T(begin), T(end), T(step)); - return PT(for_loop); - }); - - m.def("loop", - [](const PyTensor& begin, const PyTensor& end) { - Tensor& for_loop = Tensor::Loop(T(begin), T(end), T(PyTensor(1))); - return PT(for_loop); - }); - - m.def("loop", - [](const PyTensor& end) { - Tensor& for_loop = Tensor::Loop(T(PyTensor(0)), T(end), T(PyTensor(1))); - return PT(for_loop); - }); - - //if scope - m.def("if_cond", - [](const PyTensor& condition) { - Tensor& if_cond = Tensor::If(T(condition)); - return PT(if_cond); - }); - - //kernel scope - m.def("kernel", - [](py::list shape, vector group_size) { - Tensors shape_tensors = Reverse(TensorsFromList(shape)); - Tensor& kernel = Tensor::Kernel(shape_tensors, group_size); - return PT(kernel); - }, py::arg("shape"), py::arg("group_size") = vector()); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/Definitions/WindowUtils.cpp b/TensorFrost/Frontend/Python/Definitions/WindowUtils.cpp deleted file mode 100644 index bc7b3ad6..00000000 --- a/TensorFrost/Frontend/Python/Definitions/WindowUtils.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include -#include - -#include -#include - -#include "Backend/RenderDoc.h" - -namespace TensorFrost { - -void WindowDefinitions(py::module& m) { - py::module window = m.def_submodule("window", "Window functions"); - - window.def( - "show", - [](int width, int height, string title) { - ShowWindow(width, height, title.c_str()); - }, - "Show the memory manager window"); - - window.def( - "hide", []() { HideWindow(); }, "Hide the memory manager window"); - - window.def( - "render_frame", [](const PyTensorMemory& t) { RenderFrame(t.tensor_); }, - "Render a frame from the tensor memory"); - - window.def("render_frame", []() { RenderFrame(nullptr); }, - "Render an empty frame"); - - window.def( - "should_close", []() { return WindowShouldClose(); }, - "Check if the window should close"); - - window.def( - "get_mouse_position", []() { return GetMousePosition(); }, - "Get the current mouse position"); - - window.def( - "get_size", []() { return GetWindowSize(); }, - "Get the current window size"); - - window.def( - "is_mouse_button_pressed", - [](int button) { return IsMouseButtonPressed(button); }, - "Check if a mouse button is pressed"); - - window.def( - "is_key_pressed", [](int key) { return IsKeyPressed(key); }, - "Check if a key is pressed"); - - window.attr("MOUSE_BUTTON_0") = GLFW_MOUSE_BUTTON_1; - window.attr("MOUSE_BUTTON_1") = GLFW_MOUSE_BUTTON_2; - window.attr("MOUSE_BUTTON_2") = GLFW_MOUSE_BUTTON_3; - - window.attr("KEY_SPACE") = GLFW_KEY_SPACE; - window.attr("KEY_APOSTROPHE") = GLFW_KEY_APOSTROPHE; - window.attr("KEY_COMMA") = GLFW_KEY_COMMA; - window.attr("KEY_MINUS") = GLFW_KEY_MINUS; - window.attr("KEY_PERIOD") = GLFW_KEY_PERIOD; - - window.attr("KEY_A") = GLFW_KEY_A; - window.attr("KEY_B") = GLFW_KEY_B; - window.attr("KEY_C") = GLFW_KEY_C; - window.attr("KEY_D") = GLFW_KEY_D; - window.attr("KEY_E") = GLFW_KEY_E; - window.attr("KEY_F") = GLFW_KEY_F; - window.attr("KEY_G") = GLFW_KEY_G; - window.attr("KEY_H") = GLFW_KEY_H; - window.attr("KEY_I") = GLFW_KEY_I; - window.attr("KEY_J") = GLFW_KEY_J; - window.attr("KEY_K") = GLFW_KEY_K; - window.attr("KEY_L") = GLFW_KEY_L; - window.attr("KEY_M") = GLFW_KEY_M; - window.attr("KEY_N") = GLFW_KEY_N; - window.attr("KEY_O") = GLFW_KEY_O; - window.attr("KEY_P") = GLFW_KEY_P; - window.attr("KEY_Q") = GLFW_KEY_Q; - window.attr("KEY_R") = GLFW_KEY_R; - window.attr("KEY_S") = GLFW_KEY_S; - window.attr("KEY_T") = GLFW_KEY_T; - window.attr("KEY_U") = GLFW_KEY_U; - window.attr("KEY_V") = GLFW_KEY_V; - window.attr("KEY_W") = GLFW_KEY_W; - window.attr("KEY_X") = GLFW_KEY_X; - window.attr("KEY_Y") = GLFW_KEY_Y; - window.attr("KEY_Z") = GLFW_KEY_Z; - - py::module imgui = m.def_submodule("imgui", "ImGui functions"); - - imgui.def( - "begin", [](string name) { ImGuiBegin(name); }, - "Begin a new ImGui window"); - - imgui.def("end", []() { ImGuiEnd(); }, "End the current ImGui window"); - - imgui.def( - "text", [](string text) { ImGuiText(text); }, - "Add text to the current ImGui window"); - - imgui.def("slider", - [](string text, int* value, int min, int max) { - ImGuiSlider(text, value, min, max); - return *value; - }, "Add a slider to the current ImGui window"); - - imgui.def("slider", [](string text, float* value, float min, float max) { - ImGuiSlider(text, value, min, max); - return *value; - }, "Add a slider to the current ImGui window"); - - imgui.def("button", [](string text) { return ImGuiButton(text); }, - "Add a button to the current ImGui window"); - - imgui.def("checkbox", [](string text, bool* value) { - ImGuiCheckbox(text, value); - return *value; - }, "Add a checkbox to the current ImGui window"); - - imgui.def("plotlines", [](string label, py::array_t values, int values_offset, string overlay_text, float scale_min, float scale_max, py::tuple graph_size, int stride) { - ImGuiPlotLines(label.c_str(), values.data(), (int)values.size(), values_offset, overlay_text.c_str(), scale_min, scale_max, ImVec2(graph_size[0].cast(), graph_size[1].cast()), stride); - }, py::arg("label"), py::arg("values"), py::arg("values_offset") = 0, py::arg("overlay_text") = "", py::arg("scale_min") = FLT_MAX, py::arg("scale_max") = FLT_MAX, py::arg("graph_size") = py::make_tuple(0.0f, 0.0f), py::arg("stride") = sizeof(float)); - - imgui.def("scale_all_sizes", [](float scale) { ImGuiScaleAllSizes(scale); }, - "Scale all ImGui sizes by a factor"); - - imgui.def("add_background_text", [](string text, py::tuple pos, py::tuple color) { - ImGuiAddBackgroundText(text, ImVec2(pos[0].cast(), pos[1].cast()), ImVec4(color[0].cast(), color[1].cast(), color[2].cast(), color[3].cast())); - }, py::arg("text"), py::arg("pos"), py::arg("color")); - - imgui.def("color_picker3", [](string text, py::array_t color) { - ImGuiColorPicker3(text, color.mutable_data()); - return color; - }, py::arg("text"), py::arg("color")); - - imgui.def("color_picker4", [](string text, py::array_t color) { - ImGuiColorPicker4(text, color.mutable_data()); - return color; - }, py::arg("text"), py::arg("color")); - - //renderdoc - m.def("renderdoc_start_capture", []() { StartRenderDocCapture(); }, - "Start a RenderDoc capture"); - m.def("renderdoc_end_capture", []() { EndRenderDocCapture(); }, - "End a RenderDoc capture"); -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PyModule.h b/TensorFrost/Frontend/Python/PyModule.h deleted file mode 100644 index 87d2c7d4..00000000 --- a/TensorFrost/Frontend/Python/PyModule.h +++ /dev/null @@ -1,430 +0,0 @@ -#include -#include -#include -#include -#include -#include - -namespace TensorFrost { - -namespace py = pybind11; - -class Parameter { -public: - std::vector shape; - TFDataFormat dtype; - float random_scale; - float random_offset; - string debug_name; - bool optimize; - - Parameter(const std::vector& shape, TFDataFormat dtype, float random_scale = -1.0f, float random_offset = 0.0f, bool requires_grad = true) - : shape(shape), dtype(dtype), random_scale(random_scale), random_offset(random_offset), optimize(requires_grad), debug_name("") {} - - bool CanBeInitialized() { - for (int i = 0; i < shape.size(); i++) { - if (shape[i] == -1) { - return false; - } - } - return true; - } -}; - -class ParameterArray { -public: - map _parameters; //must be sorted - map _requires_grad; - - py::object getitem(size_t index) { - if (_parameters.contains(index)) { - return _parameters[index]; - } - throw py::index_error("Index " + std::to_string(index) + " is not in the ParameterArray"); - } - - void setitem(size_t index, py::object value) { - _parameters[index] = value; - if (py::isinstance(value)) { - _requires_grad[index] = py::cast(value).optimize; - } - } - - vector> items() { - vector> params; - for (auto& param : _parameters) { - params.push_back({param.first, param.second}); - } - return params; - } -}; - -class Module { -public: - enum class AttributeType { - None, - Parameter, - ParameterArray, - Module - }; - - map _attributes; - map _attribute_types; - map _optimize; - vector _attribute_order; - bool optimize = true; - - py::object tf; - - Module(bool requires_grad = true) : optimize(requires_grad) { - tf = py::module::import("TensorFrost"); - } - - py::object getattr(const std::string& name) { - if (_attributes.contains(name)) { - return _attributes[name]; - } - throw py::attribute_error("TensorFrost Module object has no attribute with name '" + name + "'"); - } - - void setattr(const std::string& name, py::object value) { - AttributeType type = AttributeType::None; - bool optimize = true; - if (py::isinstance(value)) { - type = AttributeType::Parameter; - optimize = py::cast(value).optimize; - //if not float then can't optimize - if(py::cast(value).dtype.type != TFType::Float) { - optimize = false; - } - } else if (py::isinstance(value)) { - type = AttributeType::ParameterArray; - } else if (py::isinstance(value)) { - type = AttributeType::Module; - optimize = py::cast(value).optimize; - } else if (py::isinstance(value)) { - PyTensor& tensor = py::cast(value); - tensor.Get().node_->debug_name = name; - } - - bool already_exists = _attributes.contains(name); - if(type == AttributeType::None && already_exists) { - type = _attribute_types[name]; - optimize = _optimize[name]; - } - - _attributes[name] = value; - _attribute_types[name] = type; - _optimize[name] = optimize && this->optimize; - if (!already_exists) _attribute_order.push_back(name); - } - - bool hasattr(const std::string& name) { - return _attributes.contains(name); - } - - bool param_requires_grad(const std::string& name) { - if (_optimize.contains(name)) { - return _optimize[name]; - } - return optimize; - } - - vector> get_attributes_of_type(AttributeType type) { - vector> params; - for (auto& attr : _attribute_order) { - if (_attribute_types[attr] == type) { - params.push_back({attr, _attributes[attr]}); - } - } - return params; - } - - virtual void assert_parameters() {} - - py::object initialize_parameter_native(Parameter& param) { - //Convert to list - vector shape_list; - for (int i = 0; i < param.shape.size(); i++) { - shape_list.push_back(tf.attr("const")(param.shape[i])); - } - py::object result; - if(param.dtype.type == TFType::Float) { - float shape_sum = 0.0f; - for (int i = 0; i < param.shape.size(); i++) { - shape_sum += (float)param.shape[i]; - } - float scale = sqrt(6.0f / shape_sum); - if(param.random_scale >= 0.0f) { - scale = param.random_scale; - } - result = tf.attr("random_value")(shape_list, tf.attr("const")(0u)); - result = result.attr("__mul__")(py::float_(2.0f)); - result = result.attr("__sub__")(py::float_(1.0f)); - result = result.attr("__mul__")(py::float_(scale)); - result = result.attr("__add__")(py::float_(param.random_offset)); - } else if (param.dtype.type == TFType::Int) { //just use zeros - result = tf.attr("const")(0, shape_list); - } else if (param.dtype.type == TFType::Uint) { //just use zeros - result = tf.attr("const")(0, shape_list); - } else { //just use zeros - result = tf.attr("const")(false, shape_list); - } - if(param.debug_name != "") { - result = result.attr("set_debug_name")(param.debug_name); - } - return result; - } - - void initialize_parameters_native() { - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - module.second.attr("initialize_parameters_native")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - if(!py::isinstance(param.second)) { - continue; - } - - Parameter& p = py::cast(param.second); - if (!p.CanBeInitialized()) { - //replace all -1 shapes with 1 - for (int i = 0; i < p.shape.size(); i++) { - if (p.shape[i] == -1) { - p.shape[i] = 1; - } - } - } - py::object tensor = initialize_parameter_native(p); - setattr(param.first, tensor); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - if(!py::isinstance(param.second)) { - continue; - } - Parameter& p = py::cast(param.second); - if (!p.CanBeInitialized()) { - continue; - } - py::object tensor = initialize_parameter_native(p); - param_array.setitem(param.first, tensor); - } - } - } - - void initialize_input() { - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - module.second.attr("initialize_input")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - Parameter& p = py::cast(param.second); - py::object tensor = tf.attr("input")(p.shape, p.dtype); - if(p.debug_name != "") { - tensor = tensor.attr("set_debug_name")(p.debug_name); - } - setattr(param.first, tensor); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - Parameter& p = py::cast(param.second); - py::object tensor = tf.attr("input")(p.shape, p.dtype); - if(p.debug_name != "") { - tensor = tensor.attr("set_debug_name")(p.debug_name); - } - param_array.setitem(param.first, tensor); - } - } - - assert_parameters(); - } - - py::object initialize_parameter(Parameter& param) { - py::object np = py::module::import("numpy"); - py::object random = np.attr("random"); - - // Convert the shape vector to a tuple - py::tuple shape_tuple = py::cast(param.shape); - - if(param.dtype.type == TFType::Float) { - // Generate uniform random values instead of normal - py::array_t arr = random.attr("uniform")(-1.0f, 1.0f, shape_tuple).cast>(); - float shape_sum = 0.0f; - for (int i = 0; i < param.shape.size(); i++) { - shape_sum += (float)param.shape[i]; - } - float scale = sqrt(6.0f / shape_sum); - if(param.random_scale >= 0.0f) { - scale = param.random_scale; - } - - arr = arr.attr("__mul__")(py::float_(scale)); - arr = arr.attr("__add__")(py::float_(param.random_offset)); - return tf.attr("tensor")(arr); - } else if (param.dtype.type == TFType::Int) { //just use zeros - py::array_t arr = np.attr("zeros")(shape_tuple).cast>(); - return tf.attr("tensor")(arr); - } else if (param.dtype.type == TFType::Uint) { //just use zeros - py::array_t arr = np.attr("zeros")(shape_tuple).cast>(); - return tf.attr("tensor")(arr); - } else { //just use zeros - py::array_t arr = np.attr("zeros")(shape_tuple).cast>(); - return tf.attr("tensor")(arr); - } - } - - void initialize_parameters() { - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - module.second.attr("initialize_parameters")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - if(!py::isinstance(param.second)) { - continue; - } - - Parameter& p = py::cast(param.second); - if (!p.CanBeInitialized()) { - continue; - } - py::object tensor = initialize_parameter(p); - setattr(param.first, tensor); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - if(!py::isinstance(param.second)) { - continue; - } - Parameter& p = py::cast(param.second); - if (!p.CanBeInitialized()) { - continue; - } - py::object tensor = initialize_parameter(p); - param_array.setitem(param.first, tensor); - } - } - } - - py::list parameters() { - py::list params; - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - params += module.second.attr("parameters")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - params.append(param.second); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - params.append(param.second); - } - } - return params; - } - - py::list named_parameters() { - py::list params; - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - params += module.second.attr("named_parameters")(); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - params.append(py::make_tuple(param.first, param.second)); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - params.append(py::make_tuple(array.first + "[" + std::to_string(param.first) + "]", param.second)); - } - } - return params; - } - - py::list requires_grads_list() { - py::list requires_grads; - for (auto& module : get_attributes_of_type(AttributeType::Module)) { - requires_grads.append( param_requires_grad(module.first) ); - } - - for (auto& param : get_attributes_of_type(AttributeType::Parameter)) { - requires_grads.append( param_requires_grad(param.first) ); - } - - for (auto& array : get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - bool requires_grad = param_requires_grad(array.first); - for (auto& param : param_array._parameters) { - requires_grads.append( param_array._requires_grad[param.first] && requires_grad ); - } - } - return requires_grads; - } - - py::list create_input(py::args args) { - py::list inputs = parameters(); - inputs += args; - return inputs; - } - - void update_parameters(py::object parameter_values) { - py::list params; - if (py::isinstance(parameter_values)) { - params = parameter_values; - } else if (py::isinstance(parameter_values)) { - params = py::list(parameter_values); - } else { - throw py::type_error("parameter_values must be a list or tuple"); - } - - int index = 0; - - std::function update_params; - - update_params = [&](Module& module) { - for (auto& module_item : module.get_attributes_of_type(AttributeType::Module)) { - update_params(py::cast(module_item.second)); - } - - for (auto& param : module.get_attributes_of_type(AttributeType::Parameter)) { - if (index >= py::len(params)) { - throw py::index_error("Provided more than " + std::to_string(index) + " values, but expected " + std::to_string(py::len(params))); - } - module.setattr(param.first, params[index]); - index++; - } - - for (auto& array : module.get_attributes_of_type(AttributeType::ParameterArray)) { - ParameterArray& param_array = py::cast(array.second); - for (auto& param : param_array._parameters) { - if (index >= py::len(params)) { - throw py::index_error("Provided more than " + std::to_string(index) + " values, but expected " + std::to_string(py::len(params))); - } - param_array.setitem(param.first, params[index]); - index++; - } - } - }; - - update_params(*this); - } - - virtual py::object loss(py::object X, py::object Y) { - throw std::runtime_error("Not implemented"); - } - - virtual py::object forward(py::object X) { - throw std::runtime_error("Not implemented"); - } -}; - -} \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PyTensor.cpp b/TensorFrost/Frontend/Python/PyTensor.cpp deleted file mode 100644 index a1ec2047..00000000 --- a/TensorFrost/Frontend/Python/PyTensor.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "Frontend/Python/PyTensor.h" - -#include "PyTensorMemory.h" - -namespace TensorFrost { - -PyTensors PyTensorsFromTuple(const py::tuple& tuple) { - PyTensors tensors; - for (auto arg : tuple) { - tensors.push_back(&arg.cast()); - } - return tensors; -} - -Tensors TensorsFromTuple(const py::tuple& tuple) { - Tensors tensors; - for (auto arg : tuple) { - tensors.push_back(&arg.cast().Get()); - } - return tensors; -} - -tuple SliceToTensors(const py::slice& slice) { - PyObject* pyslice = slice.ptr(); - PySliceObject* slice_obj = (PySliceObject*)pyslice; - PyObject* start = slice_obj->start; - PyObject* stop = slice_obj->stop; - PyObject* step = slice_obj->step; - - py::object start_obj = py::reinterpret_borrow(start); - py::object stop_obj = py::reinterpret_borrow(stop); - py::object step_obj = py::reinterpret_borrow(step); - - PyTensor* start_tensor = nullptr; - PyTensor* stop_tensor = nullptr; - PyTensor* step_tensor = nullptr; - - start_tensor = &start_obj.cast(); - stop_tensor = &stop_obj.cast(); - step_tensor = &step_obj.cast(); - - return {start_tensor, stop_tensor, step_tensor}; -} - -//Tensors TensorsFromTensorIndices(const Tensor* t, const py::tuple& tuple) { -// Tensors tensors; -// for (auto arg : tuple) { -// //if index is a tensor -// if (py::isinstance(arg)) { -// tensors.push_back(&arg.cast().Get()); -// } // if index is a slice -// else if (py::isinstance(arg)) { -// auto slice = arg.cast(); -// //get native python slice -// PyObject* pyslice = slice.ptr(); -// PySliceObject* slice_obj = (PySliceObject*)pyslice; -// //get start, stop, and step -// PyObject* start = slice_obj->start; -// PyObject* stop = slice_obj->stop; -// PyObject* step = slice_obj->step; -// -// //convert start, stop, and step to py::object -// py::object start_obj = py::reinterpret_borrow(start); -// py::object stop_obj = py::reinterpret_borrow(stop); -// py::object step_obj = py::reinterpret_borrow(step); -// -// //try to cast to PyTensor -// PyTensor* start_tensor = nullptr; -// PyTensor* stop_tensor = nullptr; -// PyTensor* step_tensor = nullptr; -// -// try { -// start_tensor = &start_obj.cast(); -// } catch (const py::cast_error& e) { -// //do nothing -// } -// -// try { -// stop_tensor = &stop_obj.cast(); -// } catch (const py::cast_error& e) { -// //do nothing -// } -// -// try { -// step_tensor = &step_obj.cast(); -// } catch (const py::cast_error& e) { -// //do nothing -// } -// -// //if start, stop, and step are all PyTensor -// -// -// -// -// -// else { -// throw std::invalid_argument("Invalid index type"); -// } -// } -// return tensors; -//} - -PyTensors PyTensorsFromList(const py::list& list) { - PyTensors tensors; - for (auto arg : list) { - tensors.push_back(&arg.cast()); - } - return tensors; -} - -Tensors TensorsFromList(const py::list& list) { - Tensors tensors; - for (auto arg : list) { - tensors.push_back(&arg.cast().Get()); - } - return tensors; -} - -PyTensors PyTensorsFromTensors(const Tensors& tensors) { - PyTensors py_tensors; - for (const auto* tensor : tensors) { - py_tensors.push_back(new PyTensor(tensor)); - } - return py_tensors; -} - -std::variant PyTensorsToTupleVariant(const PyTensors &tensors) { - if (tensors.size() == 1) { - //if there is only one tensor, return the tensor - return tensors[0]; - } else { - //convert to py::tuple of PyTensor* - return py::tuple(py::cast(tensors)); - } -} - -void UpdateTensorNames() { - PyObject* p = PyEval_GetLocals(); - py::dict all_names = py::reinterpret_borrow(p ? p : py::module_::import("__main__").attr("__dict__").ptr()); - - for (auto item : all_names) { - std::string var_name = py::str(item.first); - py::object var_value = py::reinterpret_borrow(item.second); - if (py::isinstance(var_value)) { - PyTensor& py_tensor = var_value.cast(); - const Tensor* tensor = &py_tensor.Get(); - tensor->SetDebugName(var_name); - } - } -} - -std::vector GetFunctionArguments(const py::function& func) { - PyObject* fn = func.ptr(); - PyObject* code_obj = PyFunction_GetCode(fn); - if (!code_obj) { - throw std::runtime_error("Could not retrieve code object"); - } - - PyObject* varnames = PyObject_GetAttrString(code_obj, "co_varnames"); - if (!varnames || !PyTuple_Check(varnames)) { - Py_XDECREF(varnames); - throw std::runtime_error("Could not retrieve varnames or varnames is not a tuple"); - } - - PyObject* argcount_obj = PyObject_GetAttrString(code_obj, "co_argcount"); - if (!argcount_obj) { - Py_XDECREF(varnames); - throw std::runtime_error("Could not retrieve argument count"); - } - int arg_count = PyLong_AsLong(argcount_obj); - Py_XDECREF(argcount_obj); - if (PyErr_Occurred()) { - Py_XDECREF(varnames); - throw std::runtime_error("Could not retrieve argument count"); - } - - PyObject* annotations = PyObject_GetAttrString(fn, "__annotations__"); - PyObject* defaults = PyObject_GetAttrString(fn, "__defaults__"); - - std::vector arg_info_list; - for (int i = 0; i < arg_count; ++i) { - PyObject* name = PyTuple_GetItem(varnames, i); - if (!name || !PyUnicode_Check(name)) { - Py_XDECREF(varnames); - Py_XDECREF(annotations); - Py_XDECREF(defaults); - throw std::runtime_error("Argument name is not a valid Unicode string"); - } - std::string arg_name = PyUnicode_AsUTF8(name); - - // Get annotation - PyObject* annotation = annotations ? PyDict_GetItemString(annotations, arg_name.c_str()) : nullptr; - - // Get default value - PyObject* default_val = (defaults && PyTuple_Check(defaults) && i >= (arg_count - PyTuple_Size(defaults))) - ? PyTuple_GetItem(defaults, i - (arg_count - PyTuple_Size(defaults))) - : nullptr; - - py::object annotation_obj = py::reinterpret_borrow(annotation); - py::object default_obj = py::reinterpret_borrow(default_val); - - arg_info_list.emplace_back(arg_name, annotation_obj, default_obj); - } - - Py_XDECREF(varnames); - Py_XDECREF(annotations); - Py_XDECREF(defaults); - - return arg_info_list; -} - -py::array ListToArray(py::list input_list) { - // Get the numpy module - py::module np = py::module::import("numpy"); - - // Convert the list to a numpy array - py::array np_array = np.attr("array")(input_list); - - return np_array; -} - -std::string r_op(const std::string& name) { return "__r" + name + "__"; } - -std::string l_op(const std::string& name) { return "__" + name + "__"; } - -} // namespace TensorFrost diff --git a/TensorFrost/Frontend/Python/PyTensor.h b/TensorFrost/Frontend/Python/PyTensor.h deleted file mode 100644 index 84a50c4f..00000000 --- a/TensorFrost/Frontend/Python/PyTensor.h +++ /dev/null @@ -1,104 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace TensorFrost { - -#define PT(tensor) PyTensor(&(tensor)) -#define T(tensor) (tensor).Get() - -namespace py = pybind11; - -void UpdateTensorNames(); -class PyTensor; - -using PyTensors = std::vector; - -PyTensors PyTensorsFromTuple(const py::tuple& tuple); -Tensors TensorsFromTuple(const py::tuple& tuple); -PyTensors PyTensorsFromList(const py::list& list); -Tensors TensorsFromList(const py::list& list); -PyTensors PyTensorsFromTensors(const Tensors& tensors); -std::variant PyTensorsToTupleVariant(const PyTensors& tensors); - -using ArgInfo = std::tuple; // (name, annotation, default) - -vector GetFunctionArguments(const py::function& func); - -py::array ListToArray(py::list input_list); - -// Tensor wrapper for python -class PyTensor { - Tensor* tensor_; - Tensors indices; - Tensor* value = nullptr; - - public: - explicit PyTensor(Tensor* tensor) : tensor_(tensor) {} - explicit PyTensor(const Tensor* tensor) : tensor_(const_cast(tensor)) {} - ~PyTensor() { UpdateTensorNames(); } - - //tensor view constructor - explicit PyTensor(const Tensor* value, Tensors& indices) - : value(const_cast(value)), indices(std::move(indices)) { - tensor_ = &Tensor::Load(*value, this->indices); - } - - const Tensor& Get() const { return *tensor_; } - - Tensor* Value() const { - if (value == nullptr) { - throw std::runtime_error("Not a tensor view"); - } - return value; - } - - Tensors Indices() const { - if (value == nullptr) { - throw std::runtime_error("Not a tensor view"); - } - return indices; - } - - explicit PyTensor(float value) { tensor_ = &Tensor::Constant(value); } - explicit PyTensor(int value) { tensor_ = &Tensor::Constant(value); } - explicit PyTensor(unsigned int value) { tensor_ = &Tensor::Constant(value); } - explicit PyTensor(bool value) { tensor_ = &Tensor::Constant(value); } - - std::variant __enter__() { - //py::print("Entering node scope"); - std::variant entered = tensor_->Enter(); - if (std::holds_alternative(entered)) { - return new PyTensor(std::get(entered)); - } else { - auto tensors = std::get(entered); - //convert to py::tuple of PyTensor* - return PyTensorsToTupleVariant(PyTensorsFromTensors(Reverse(tensors))); - } - } - - void __exit__(py::object exc_type, py::object exc_value, - py::object traceback) { - //py::print("Exiting node scope"); - tensor_->Exit(); - } -}; - -class PyTensorArg { -public: - std::vector shape; - TFDataFormat type; - - PyTensorArg(std::vector shape, TFDataFormat type) - : shape(std::move(shape)), type(type) {} -}; - -std::string r_op(const std::string& name); -std::string l_op(const std::string& name); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PyTensorMemory.cpp b/TensorFrost/Frontend/Python/PyTensorMemory.cpp deleted file mode 100644 index 75b1ee34..00000000 --- a/TensorFrost/Frontend/Python/PyTensorMemory.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "Frontend/Python/PyTensor.h" -#include "Frontend/Python/PyTensorMemory.h" - -namespace TensorFrost { -PyTensorMemory::PyTensorMemory(py::array arr) { - py::buffer_info info = arr.request(); - - // Get the shape - std::vector shape; - for (size_t i = 0; i < (size_t)info.ndim; i++) { - shape.push_back(info.shape[i]); - } - - // Create the data vector - std::vector data; - data.reserve(info.size); - - // Determine the data type and conversion function - std::function convert; - TFDataFormat format = TFTypeNone; - switch (info.format[0]) { - case 'f': // float32 - convert = [](char* ptr) { return *reinterpret_cast(ptr); }; - format = TFTypeFloat32; - break; - case 'i': // int32 - convert = [](char* ptr) { return static_cast(*reinterpret_cast(ptr)); }; - format = TFTypeInt32; - break; - case 'q': // int64 (convert to int32 before casting) - convert = [](char* ptr) { - int32_t val = (int32_t)*reinterpret_cast(ptr); - return *reinterpret_cast(&val); - }; - format = TFTypeInt32; - break; - case 'Q': // uint64 (convert to uint32 before casting) - convert = [](char* ptr) { - uint32_t val = (uint32_t)*reinterpret_cast(ptr); - return val; - }; - format = TFTypeUint32; - break; - case 'L': // uint32 - case 'I': // uint32 - convert = [](char* ptr) { return *reinterpret_cast(ptr); }; - format = TFTypeUint32; - break; - case '?': // bool - convert = [](char* ptr) { return static_cast(*reinterpret_cast(ptr)); }; - format = TFTypeBool32; - break; - case 'd': // float64 (convert to float32 before casting) - convert = [](char* ptr) { float val = (float)*reinterpret_cast(ptr); return *reinterpret_cast(&val); }; - format = TFTypeFloat32; - break; - case 'l': // int64 (convert to int32 before casting) - convert = [](char* ptr) { int32_t val = (int32_t)*reinterpret_cast(ptr); return *reinterpret_cast(&val); }; - format = TFTypeInt32; - break; - default: - throw std::runtime_error("Unsupported data type to create TensorMemory from numpy array, format: " + std::string(info.format)); - } - - // Define a recursive lambda function for multi-dimensional iteration - std::function&)> iter_dims; - iter_dims = [&](const size_t dim, std::vector& indices) { - if (dim == info.ndim) { - // Calculate the actual memory address using strides - char* ptr = static_cast(info.ptr); - for (size_t i = 0; i < (size_t)info.ndim; ++i) { - ptr += indices[i] * info.strides[i]; - } - data.push_back(convert(ptr)); - } else { - for (indices[dim] = 0; indices[dim] < (size_t)info.shape[dim]; ++indices[dim]) { - iter_dims(dim + 1, indices); - } - } - }; - - // Start the multidimensional iteration - std::vector start_indices(info.ndim, 0); - iter_dims(0, start_indices); - - // Allocate the memory - tensor_ = global_memory_manager->AllocateTensorWithData(shape, data, format); -} - -vector TensorMemoryFromTuple(const py::tuple& tuple) { - vector memories; - for (auto arg : tuple) { - memories.push_back(&arg.cast()); - } - return memories; -} - -vector TensorMemoryFromList(const py::list& list) { - vector memories; - for (auto arg : list) { - memories.push_back(&arg.cast()); - } - return memories; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PyTensorMemory.h b/TensorFrost/Frontend/Python/PyTensorMemory.h deleted file mode 100644 index e5ca77ed..00000000 --- a/TensorFrost/Frontend/Python/PyTensorMemory.h +++ /dev/null @@ -1,60 +0,0 @@ -#include -#include -#include -#include -#include -#include - -namespace TensorFrost { - -namespace py = pybind11; - -// Tensor wrapper for python -class PyTensorMemory { - public: - TFTensor* tensor_; - - explicit PyTensorMemory(TFTensor* tensor) : tensor_(tensor) {} - - PyTensorMemory(vector shape, TFDataFormat type = TFTypeFloat32) { - tensor_ = global_memory_manager->AllocateTensor(shape, type); - } - - TFDataFormat GetFormat() const { - return tensor_->format; - } - - PyTensorMemory(py::array arr); - - template - py::array_t ToPyArray() const { - // Get the shape - std::vector shape = GetShape(tensor_); - - // Create the numpy array - py::array_t arr(shape); - - // Copy the data - std::vector data = global_memory_manager->Readback(tensor_); - T* ptr = static_cast(arr.request().ptr); - for (int i = 0; i < data.size(); i++) { - ptr[i] = *(reinterpret_cast(&data[i])); - } - - return arr; - } - - vector Shape() const { - return TensorFrost::GetShape(tensor_); - } - - ~PyTensorMemory() { - global_memory_manager->DeallocateTensor(*tensor_); - } - -}; - -vector TensorMemoryFromTuple(const py::tuple& tuple); -vector TensorMemoryFromList(const py::list& list); - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Frontend/Python/PybindModule.cpp b/TensorFrost/Frontend/Python/PybindModule.cpp deleted file mode 100644 index c3351736..00000000 --- a/TensorFrost/Frontend/Python/PybindModule.cpp +++ /dev/null @@ -1,128 +0,0 @@ -#include -#include - -#include -#include -#include - -namespace TensorFrost { - -void PyTensorDefinition(py::module&, py::class_&); -void TensorFunctionsDefinition(py::module&); -void TensorProgramDefinition(py::module&, py::class_&); -void TensorMemoryDefinition(py::module& m, - py::class_& py_tensor_mem); -void WindowDefinitions(py::module& m); -void ScopeDefinitions(py::module& m, py::class_& py_tensor); -void ModuleDefinitions(py::module& m); - -PYBIND11_MODULE(TensorFrost, m) { - m.doc() = "TensorFrost library"; - auto data_type = py::enum_(m, "TFType"); - auto backend_type = py::enum_(m, "BackendType"); - auto code_gen_lang = py::enum_(m, "CodeGenLang"); - auto py_tensor = py::class_(m, "Tensor"); - auto tensor_program = py::class_(m, "TensorProgram"); - auto py_tensor_mem = py::class_(m, "TensorMemory"); - auto py_tensor_arg = py::class_(m, "Arg"); - - data_type.value("float", TFType::Float); - data_type.value("int", TFType::Int); - data_type.value("uint", TFType::Uint); - data_type.value("bool", TFType::Bool); - - auto data_format = py::class_(m, "TFDataFormat"); - data_format.def(py::init()); - data_format.def_readwrite("type", &TFDataFormat::type); - data_format.def_readwrite("size", &TFDataFormat::size); - // Add printers for the enums - data_format.def("__repr__", [](const TFDataFormat& a) { - return ""; - }); - data_format.def("__str__", [](const TFDataFormat& a) { - return ""; - }); - data_type.def("__repr__", [](TFType a) { - return ""; - }); - data_type.def("__str__", [](TFType a) { - return ""; - }); - - backend_type.value("cpu", BackendType::CPU); - backend_type.value("vulkan", BackendType::Vulkan); - backend_type.value("opengl", BackendType::OpenGL); - backend_type.value("codegen", BackendType::CodeGen); - code_gen_lang.value("cpp", CodeGenLang::CPP); - code_gen_lang.value("glsl", CodeGenLang::GLSL); - code_gen_lang.value("hlsl", CodeGenLang::HLSL); - - data_format.def(py::self == py::self); - backend_type.def("__eq__", [](BackendType a, BackendType b) { - return a == b; - }); - code_gen_lang.def("__eq__", [](CodeGenLang a, CodeGenLang b) { - return a == b; - }); - data_type.def("__eq__", [](TFType a, TFType b) { - return a == b; - }); - - m.attr("float32") = TFTypeFloat32; - m.attr("int32") = TFTypeInt32; - m.attr("uint32") = TFTypeUint32; - m.attr("bool1") = TFTypeBool32; - - m.attr("cpu") = BackendType::CPU; - m.attr("vulkan") = BackendType::Vulkan; - m.attr("opengl") = BackendType::OpenGL; - m.attr("codegen") = BackendType::CodeGen; - - m.attr("cpp_lang") = CodeGenLang::CPP; - m.attr("glsl_lang") = CodeGenLang::GLSL; - m.attr("hlsl_lang") = CodeGenLang::HLSL; - - PyTensorDefinition(m, py_tensor); - - // implicit conversion from TensorView to PyTensor - py::implicitly_convertible(); - py::implicitly_convertible(); - py::implicitly_convertible(); - py::implicitly_convertible(); - - TensorFunctionsDefinition(m); - TensorProgramDefinition(m, tensor_program); - TensorMemoryDefinition(m, py_tensor_mem); - WindowDefinitions(m); - ScopeDefinitions(m, py_tensor); - ModuleDefinitions(m); - - py_tensor_arg.def(py::init([](py::list shape, TFDataFormat type) { - std::vector shape_vec; - for (auto& s : shape) { - shape_vec.push_back(s.cast()); - } - return PyTensorArg(shape_vec, type); - }), "Create a TensorArg with the given shape and type", py::return_value_policy::take_ownership); - - m.def("current_backend", []() { - return current_backend; - }, "Get the current backend"); - - m.def("initialize", - [](BackendType backend_type, const std::string& kernel_compile_options, CodeGenLang kernel_lang) { - InitializeBackend(backend_type, kernel_compile_options, kernel_lang); - }, py::arg("backend_type") = BackendType::CPU, py::arg("kernel_compile_options") = "", py::arg("kernel_lang") = CodeGenLang::None, "Initialize the backend"); - - m.def("strip_debug_info", [](bool strip) { - strip_debug_names = strip; - }, py::arg("strip") = true, "Strip debug info from the kernel"); - -#ifdef NDEBUG - py::print("TensorFrost module loaded!"); -#else - py::print("TensorFrost module loaded in debug mode! Expect slow performance."); -#endif -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/PybindModule.cpp b/TensorFrost/PybindModule.cpp new file mode 100644 index 00000000..8b88e2f5 --- /dev/null +++ b/TensorFrost/PybindModule.cpp @@ -0,0 +1,247 @@ +#include +#include + +// #include +// #include +#include +#include +#include +#include +#include + +#include "TensorFrost.h" +#include "Backend/RenderDoc.h" +#include "Definitions/VulkanBindings.h" + +namespace py = pybind11; + +namespace TensorFrost { +// +// void PyTensorDefinition(py::module&, py::class_&); +// void TensorFunctionsDefinition(py::module&); +// void TensorProgramDefinition(py::module&, py::class_&); +// void TensorMemoryDefinition(py::module& m, +// py::class_& py_tensor_mem); +// void ScopeDefinitions(py::module& m, py::class_& py_tensor); +// void ModuleDefinitions(py::module& m); + +PYBIND11_MODULE(TensorFrost, m) { + m.doc() = "TensorFrost library"; + // auto data_type = py::enum_(m, "TFType"); + // auto backend_type = py::enum_(m, "BackendType"); + // auto code_gen_lang = py::enum_(m, "CodeGenLang"); + // auto py_tensor = py::class_(m, "Tensor"); + // auto tensor_program = py::class_(m, "TensorProgram"); + // auto py_tensor_mem = py::class_(m, "TensorMemory"); + // auto py_tensor_arg = py::class_(m, "Arg"); + // + // data_type.value("float", TFType::Float); + // data_type.value("int", TFType::Int); + // data_type.value("uint", TFType::Uint); + // data_type.value("bool", TFType::Bool); + // + // auto data_format = py::class_(m, "TFDataFormat"); + // data_format.def(py::init()); + // data_format.def_readwrite("type", &TFDataFormat::type); + // data_format.def_readwrite("size", &TFDataFormat::size); + // // Add printers for the enums + // data_format.def("__repr__", [](const TFDataFormat& a) { + // return ""; + // }); + // data_format.def("__str__", [](const TFDataFormat& a) { + // return ""; + // }); + // data_type.def("__repr__", [](TFType a) { + // return ""; + // }); + // data_type.def("__str__", [](TFType a) { + // return ""; + // }); + // + // backend_type.value("cpu", BackendType::CPU); + // backend_type.value("vulkan", BackendType::Vulkan); + // backend_type.value("opengl", BackendType::OpenGL); + // backend_type.value("codegen", BackendType::CodeGen); + // code_gen_lang.value("cpp", CodeGenLang::CPP); + // code_gen_lang.value("glsl", CodeGenLang::GLSL); + // code_gen_lang.value("hlsl", CodeGenLang::HLSL); + // + // data_format.def(py::self == py::self); + // backend_type.def("__eq__", [](BackendType a, BackendType b) { + // return a == b; + // }); + // code_gen_lang.def("__eq__", [](CodeGenLang a, CodeGenLang b) { + // return a == b; + // }); + // data_type.def("__eq__", [](TFType a, TFType b) { + // return a == b; + // }); + // + // m.attr("float32") = TFTypeFloat32; + // m.attr("int32") = TFTypeInt32; + // m.attr("uint32") = TFTypeUint32; + // m.attr("bool1") = TFTypeBool32; + // + // m.attr("cpu") = BackendType::CPU; + // m.attr("vulkan") = BackendType::Vulkan; + // m.attr("opengl") = BackendType::OpenGL; + // m.attr("codegen") = BackendType::CodeGen; + // + // m.attr("cpp_lang") = CodeGenLang::CPP; + // m.attr("glsl_lang") = CodeGenLang::GLSL; + // m.attr("hlsl_lang") = CodeGenLang::HLSL; + // + // PyTensorDefinition(m, py_tensor); + // + // // implicit conversion from TensorView to PyTensor + // py::implicitly_convertible(); + // py::implicitly_convertible(); + // py::implicitly_convertible(); + // py::implicitly_convertible(); + // + // TensorFunctionsDefinition(m); + // TensorProgramDefinition(m, tensor_program); + // TensorMemoryDefinition(m, py_tensor_mem); + // ScopeDefinitions(m, py_tensor); + // ModuleDefinitions(m); + // + // py_tensor_arg.def(py::init([](py::list shape, TFDataFormat type) { + // std::vector shape_vec; + // for (auto& s : shape) { + // shape_vec.push_back(s.cast()); + // } + // return PyTensorArg(shape_vec, type); + // }), "Create a TensorArg with the given shape and type", py::return_value_policy::take_ownership); + // + // m.def("current_backend", []() { + // return current_backend; + // }, "Get the current backend"); + // + // m.def("initialize", + // [](BackendType backend_type, const std::string& kernel_compile_options, CodeGenLang kernel_lang) { + // InitializeBackend(backend_type, kernel_compile_options, kernel_lang); + // }, py::arg("backend_type") = BackendType::CPU, py::arg("kernel_compile_options") = "", py::arg("kernel_lang") = CodeGenLang::None, "Initialize the backend"); + // + // m.def("strip_debug_info", [](bool strip) { + // strip_debug_names = strip; + // }, py::arg("strip") = true, "Strip debug info from the kernel"); + +#ifdef NDEBUG + py::print("TensorFrost module loaded!"); +#else + py::print("TensorFrost module loaded in debug mode! Expect slow performance."); +#endif + + // // TEST CODE + // TFProgram program([]() -> auto { + // Values inputs; + // Values outputs; + // Value a = 5; + // Value b = 10; + // Value f = 2.5f; + // Value g = 3.5f; + // Value c = a + b * 3; + // Value mem = memory({a, b, c}, TFFloat32); + // inputs.push_back(mem); + // vmap({a, b, c}, [&](Values ids0) { + // Value something = tofloat(mem * sin(f + g)); + // outputs.push_back(something); + // }); + // vmap({a, b, c}, [&](Values ids0) { + // Value imem = toint(mem * sin(f + g)); + // Value d = c + b + ids0[0] * imem; + // Value m0; + // vmap({c}, [&](Values ids1) { + // m0 = 0; + // }); + // if_cond(d > 0, [&]() { + // Value t = d * c * imem; + // vmap({c}, [&](Values ids1) { + // m0.Set(t * imem[{ids0[1], ids0[1], ids0[1]}]); + // }); + // }, [&]() { + // Value t = d * c / imem; + // vmap({c}, [&](Values ids1) { + // m0.Set(t / imem[{ids1[0], ids0[0], ids0[1]}]); + // }); + // }); + // vmap({c, c}, [&](Values ids1) { + // Value m = m0 * imem[{ids1[1], ids1[0], ids0[0]}]; + // outputs.push_back(m); + // }); + // }); + // return std::make_pair(inputs, outputs); + // }); + // program.Compile(); + // py::print(program.DebugPrint()); + + VulkanDefinitions(m); + + m.def("renderdoc_is_available", []() { + return IsRenderDocAvailable(); + }, "Check if RenderDoc is attached to this process"); + + m.def("renderdoc_start_capture", []() { + StartRenderDocCapture(); + }, "Start a RenderDoc capture"); + + m.def("renderdoc_end_capture", + [](bool launch_replay_ui) { + return EndRenderDocCapture(launch_replay_ui); + }, + py::arg("launch_replay_ui") = false, + "End the current RenderDoc capture and optionally launch the RenderDoc replay UI. " + "Returns the capture path if one was produced, otherwise an empty string."); +// VulkanContext ctx; +// +// const size_t N = 1024; +// // create buffers +// Buffer aBuf = createBuffer(ctx, N, sizeof(float), true); +// Buffer bBuf = createBuffer(ctx, N, sizeof(float), true); +// Buffer outBuf = createBuffer(ctx, N, sizeof(float), false); +// +// // map and write input data +// float* aPtr = static_cast(ctx.device.mapMemory(aBuf.memory, 0, aBuf.size)); +// float* bPtr = static_cast(ctx.device.mapMemory(bBuf.memory, 0, bBuf.size)); +// for (size_t i = 0; i < N; i++) { +// aPtr[i] = static_cast(i); +// bPtr[i] = static_cast(2 * i); +// } +// ctx.device.unmapMemory(aBuf.memory); +// ctx.device.unmapMemory(bBuf.memory); +// +// std::string code = R"( +// [[vk::binding(2,0)]] RWStructuredBuffer C; +// [[vk::binding(0,0)]] StructuredBuffer A; +// [[vk::binding(1,0)]] StructuredBuffer B; +// +// [shader("compute")] [numthreads(64,1,1)] +// void computeMain(uint3 tid: SV_DispatchThreadID) { +// C[tid.x] = 2.0f*A[tid.x] + B[tid.x]; +// } +// )"; +// ComputeProgram prog = createComputeProgramFromSlang(ctx, "vecadd", code, "computeMain", { &aBuf, &bBuf },{ &outBuf }); +// +// // run compute +// runProgram(ctx, prog, N); +// +// // read back result +// float* outPtr = static_cast(ctx.device.mapMemory(outBuf.memory, 0, outBuf.size)); +// bool ok = true; +// for (size_t i = 0; i < N; i++) { +// float expected = 2.0f*aPtr[i] + bPtr[i]; +// if (outPtr[i] != expected) { +// ok = false; break; +// } +// } +// ctx.device.unmapMemory(outBuf.memory); +// py::print("Compute result is ", ok ? "correct" : "incorrect"); +// +// // cleanup +// destroyComputeProgram(ctx, prog); +// destroyBuffer(ctx, aBuf); +// destroyBuffer(ctx, bBuf); +// destroyBuffer(ctx, outBuf); +} + +} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Tensor/Tensor.cpp b/TensorFrost/Tensor/Tensor.cpp deleted file mode 100644 index 000620d5..00000000 --- a/TensorFrost/Tensor/Tensor.cpp +++ /dev/null @@ -1,357 +0,0 @@ -#include "Tensor.h" -#include "Backend/CodeGen/Generators.h" - -namespace TensorFrost { - -Tensors Reverse(const Tensors& tensors) { - Tensors reversed; - for (int i = (int)tensors.size() - 1; i >= 0; i--) { - reversed.push_back(tensors[i]); - } - return reversed; -} - -vector Reverse(const vector& vec) { - vector reversed; - for (int i = (int)vec.size() - 1; i >= 0; i--) { - reversed.push_back(vec[i]); - } - return reversed; -} - -int ReverseDim(int dim, size_t dims) { - return (int)dims - dim - 1; -} - -IR* Tensor::evaluation_context_ir_ = nullptr; - -Node::~Node() { delete tensor_; } - -std::string MakeNodeErrorMessage(std::string message, std::initializer_list nodes) { - message += "\n"; - for (auto node : nodes) { - message += GetNodeString(node) + "\n"; - } - return message; -} - -vector ShapeInfo::GetShape(int default_value) const { - vector shape; - for (auto [node, _]: this->shape) { - if(node->op->class_ == OpClass::Constant) { - shape.push_back(node->tensor_->TryGetConstant()); - } else { - shape.push_back(default_value); - } - } - return shape; -} - -void ShapeInfo::ExpandDimensionsTo(int new_dim) -{ - if(new_dim <= dim) { - return; - } - Tensor& one = Tensor::Constant(1); - for(int i = dim; i < new_dim; i++) { - InsertDim(i, one.node_, true); - } -} - -float ShapeInfo::GetSizeEstimate(ShapeInfo& shape) { - vector shape_a = shape.GetShape(); - float size_a = 1.0f; - for (int i = 0; i < shape_a.size(); i++) { - size_a *= (float)shape_a[i]; - } - return size_a; -} - -void ArgumentManager::AddArgument(ArgID id, Node* node) { - if(node == nullptr) { - throw std::runtime_error("Node is null"); - } - inputs_[id] = node; - argument_types_[id] = node->format; - argument_counts_[id.first]++; - //add this node as an output of the argument - node->args.AddOutput(id, node_); -} - -void ArgumentManager::Remove(ArgID id) { - if(inputs_.find(id) == inputs_.end()) { - throw std::runtime_error("Cannot remove argument that does not exist"); - } - //remove this node as an output of the argument - inputs_[id]->args.RemoveOutput(id, node_); - inputs_.erase(id); - argument_types_.erase(id); - argument_counts_[id.first]--; -} - -void ArgumentManager::RemoveArguments(ArgType arg) { - vector to_remove; - for (auto& [id, node] : inputs_) { - if (id.first == arg) { - to_remove.push_back(id); - } - } - for (auto& id : to_remove) { - Remove(id); - } -} - -vector ArgumentManager::GetTensorVector(ArgType type) const { - vector tensors = vector(Count(type)); - for (auto& [id, node] : inputs_) { - if (id.first == type) { - tensors[id.second] = node->tensor_; - } - } - return tensors; -} - -tuple Tensor::GetOperation(const string &name, const Tensors &tensors, - bool check_shape) { - vector input_types = vector(); - for (const auto& tensor : tensors) { - input_types.push_back(tensor->node_->format); - } - - const Operation* operation = FindOperation(name); - - // check if input is valid - if (!operation->IsInputValid(input_types)) { - string error = "Input types ("; - for (int i = 0; i < input_types.size(); i++) { - error += DataTypeToString(input_types[i].type) + "(" + to_string(input_types[i].size) + ")"; - if (i < input_types.size() - 1) { - error += ", "; - } - } - error += ") are not valid for operation \"" + name + "\""; - throw std::runtime_error(error); - } - - ShapeInfo shape_info = ShapeInfo(); - - if (check_shape) - { - //check if shapes are compatible and get the final broadcasted shape - for (int i = 0; i < tensors.size(); i++) { - ShapeInfo shape_info2 = ShapeInfo(tensors[i]->node_); - auto result = CompareShape(shape_info, shape_info2, true); - shape_info = result.broadcast_shape; - } - } - - TFDataFormat output_type = operation->GetOutputType(input_types); - - return {operation, output_type, shape_info}; -} - -bool Tensor::CheckIndices(const Tensors &indices) { - for (const Tensor* index : indices) { - if (index->node_->format.type != TFType::Int && index->node_->format.type != TFType::Uint) { - return false; - } - } - return true; -} - -TFDataFormat Tensor::GetFormat() const { return node_->format; } - -void Tensor::SetData(const vector &data) const { - node_->data = data; -} - -void Tensor::SetData(uint data) const { - SetData(vector(1, data)); -} - -void Tensor::SetData(float data) const { - SetData(vector(1, AsUint(data))); -} - -void Tensor::SetData(int data) const { - SetData(vector(1, AsUint(data))); -} - -void Tensor::SetFormat(TFDataFormat type) const { - node_->format = type; -} - -void Tensor::DetachGrad() const { - node_->flags.set(NodeProp::DetachGrad); -} - -void Tensor::PassGrad() const { - node_->flags.set(NodeProp::PassGrad); -} - -void Tensor::StopFusion() const { - node_->flags.set(NodeProp::StopFusion); -} - -void Tensor::HintRange(float min, float max) const { - node_->flags.set(NodeProp::HintMinValue, (int64_t)AsUint(min)); - node_->flags.set(NodeProp::HintMaxValue, (int64_t)AsUint(max)); -} - -void Tensor::HintRange(int min, int max) const { - node_->flags.set(NodeProp::HintMinValue, (int64_t)min); - node_->flags.set(NodeProp::HintMaxValue, (int64_t)max); -} - -void Tensor::HintRange(uint min, uint max) const { - node_->flags.set(NodeProp::HintMinValue, (int64_t)min); - node_->flags.set(NodeProp::HintMaxValue, (int64_t)max); -} - -Tensor* Tensor::GetCopy(const Tensor& other, NodeArguments args) { - Tensor* copy = &CreateNode(other.node_->format, std::move(args), other.node_->name); - copy->node_->data = other.node_->data; - copy->node_->CopyProperties(other.node_); - return copy; -} - -Tensor* Tensor::GetCopy(const Tensor& other) { - NodeArguments new_args; - for (auto& [id, from] : other.node_->args.Inputs()) { - new_args[id] = from; - } - return GetCopy(other, new_args); -} - -void Tensor::SetShape(Tensors shape) const { - node_->args.RemoveArguments(ArgType::Shape); - for (int i = 0; i < shape.size(); i++) { - node_->args.AddArgument(ArgType::Shape, i, shape[i]->node_); - } -} - -Tensors Tensor::GetInputShapeTensors(Tensors shape) { - Tensors result = Tensors(); - for (int dim = 0; dim < shape.size(); dim++) { - const Tensor* tensor = shape[dim]; - //check if tensor is a negative constant - if (tensor->node_->name == "const" && (*(int*)&(tensor->node_->data[0])) < 0) - { - Tensor& mem = Static("input_shape", {TFType::Int, 32}); - //make sure its reversed on the backend - mem.node_->flags.set(NodeProp::InputShapeDim, (int64_t)(shape.size() - dim - 1)); - result.push_back(&mem); - } - else - { - result.push_back(tensor); - } - } - return result; -} - -//Get values from a tensor at the given indices -Tensor& Tensor::Load(const Tensor& tensor, const Tensors& indices, IndexingMode mode) { - Tensor& out = MemoryOp("load", &tensor, indices); - out.node_->indexing_mode_ = mode; - out.SetData(0); - out.SetDebugName(tensor.node_->debug_name); - return out; -} - -Tensor& Tensor::Store(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode) { - Tensor& out = MemoryOp("store", &tensor, indices, &value); - out.node_->indexing_mode_ = mode; - return out; -} - -Tensor & Tensor::ReductionOP(string name, const Tensor &tensor, int axis, bool keepdims) { - // get the shape of the tensor (all dimensions except the last one) - Tensors shape = tensor.GetShape(); - axis = GetAxis((int)shape.size(), axis); - - //check if axis is valid - if (axis < 0 || axis >= shape.size()) { - throw std::runtime_error("Invalid axis for reduction operation " + name); - } - - // remove the axis dimension - shape.erase(shape.begin() + axis); - if (shape.empty()) { - shape.push_back(&Constant(1)); - } - Tensor& op = OpShape(name, shape, &tensor); - op.node_->data = vector(1, axis); - //if(keepdims) { - // op.node_->AddFlag(NodeFlag::KeepDims); - //} - return op; -} - -Tensor & Tensor::ScanOP(string name, const Tensor &tensor, int axis) { - Tensor& op = Op(name, &tensor); - op.node_->data = vector(1, axis); - return op; -} - -bool Tensor::AreTensorsEqual(const Tensor &a, const Tensor &b) { - if(a.node_->op->class_ == OpClass::Constant && b.node_->op->class_ == OpClass::Constant) { - return a.node_->data[0] == b.node_->data[0]; - } - if(a.node_ == b.node_) { - return true; - } - return false; -} - -Tensor& Tensor::Reshape(const Tensor& tensor, const Tensors& shape, TFDataFormat format) { - Tensor& out = MemoryOpShape("reshape", shape, &tensor); - out.SetDebugName(tensor.node_->debug_name); - if(format.type != TFType::None) { - out.node_->format = format; - } else { - out.node_->format = tensor.node_->format; - } - return out; -} - -Tensor & Tensor::Assert(const Tensor &tensor, const Tensors &shape, TFDataFormat type) { - Tensor& out = MemoryOpShape("assert", shape, &tensor); - out.SetDebugName(tensor.node_->debug_name); - out.node_->format = type; - return out; -} - -void Tensor::SetDebugName(const string& name) const -{ - if (name != "") { - node_->debug_name = name; - } - - if (strip_debug_names) { - static int count = 0; - - node_->debug_name = "tensor_" + std::to_string(count++); - } -} - -void Tensor::BeginRegion(const string& name) { - Tensor& t = Static("region_begin", {TFType::None, 0}); - t.SetDebugName(name); -} - -void Tensor::EndRegion(const string& name) { - Tensor& t = Static("region_end", {TFType::None, 0}); - t.SetDebugName(name); -} - -const Tensor* Node::GetTensor() const { - if (tensor_->node_ != this) { - throw std::runtime_error("Fatal Error: Tensor node does not match"); - } - return tensor_; -} - - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Tensor/Tensor.h b/TensorFrost/Tensor/Tensor.h deleted file mode 100644 index 40deee5e..00000000 --- a/TensorFrost/Tensor/Tensor.h +++ /dev/null @@ -1,1228 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -// for FLT_MAX, INT_MAX, etc. -#include -#include - -#include - -#include "Compiler/Graph/IR.h" -#include "Utility/Utility.h" - -namespace TensorFrost { - -using Tensors = vector; - -Tensors Reverse(const Tensors& tensors); -vector Reverse(const vector& vec); -int ReverseDim(int dim, size_t dims); - -class Tensor { - private: - static IR* evaluation_context_ir_; - - static Tensor& CreateNode(TFDataFormat type, NodeArguments args, string name) { - if (evaluation_context_ir_ == nullptr) { - throw std::runtime_error( - "Evaluation context has not been set. Are you doing operations " - "without compiling first?"); - } - - //TODO: merge Tensor and Node classes - auto* tensor = new Tensor(); - tensor->node_ = evaluation_context_ir_->AddNode(tensor, std::move(args), std::move(name), type); - return *tensor; - } - - static void AddArgument(NodeArguments& arguments, const Tensor* tensor, - ArgType type, int index = 0) { - arguments[ArgID(type, index)] = tensor->node_; - } - - static void AddArguments(NodeArguments& arguments, const Tensors& tensors, - ArgType type) { - for (int i = 0; i < tensors.size(); i++) { - AddArgument(arguments, tensors[i], type, i); - } - } - - static void AddArguments(NodeArguments& arguments, const NodeArguments& toadd) { - for (const auto& i : toadd) { - arguments[i.first] = i.second; - } - } - - static bool AssertTensorShape(const Tensor* a, const Tensor* b, bool throw_error = true) { - return CompareShape(a->node_, b->node_, throw_error).compatible; - } - - static tuple GetOperation(const string& name, - const Tensors& tensors, bool check_shape = true); - - template - static Tensor& Op(std::string op, const Args*... args) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Operation name cannot be empty"); - } - - // convert the parameter pack to a std::vector - Tensors tensors = {args...}; - - // get the operation and output type - auto [operation, output_type, shape_info] = GetOperation(op, tensors); - - // create argument list - NodeArguments arguments = NodeArguments(); - - AddArguments(arguments, tensors, ArgType::Input); - - AddArguments(arguments, shape_info.GetArguments()); - - return CreateNode(output_type, arguments, op); - } - -public: - static Tensor& OpShape(std::string op, Tensors shape, Tensors tensors) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Operation name cannot be empty"); - } - - // get the operation and output type - auto [operation, output_type, shape_info] = GetOperation(op, tensors, false); - - // create argument list - NodeArguments arguments = NodeArguments(); - - AddArguments(arguments, tensors, ArgType::Input); - AddArguments(arguments, shape, ArgType::Shape); - - return CreateNode(output_type, arguments, op); - } - - template - static Tensor& OpShape(std::string op, Tensors shape, const Args*... args) { - // convert the parameter pack to a std::vector - Tensors tensors = {args...}; - - return OpShape(op, shape, tensors); - } - - template - static Tensor& MemoryOpShape(std::string op, Tensors shape, const Tensor* memory, const Args*... args) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Memory operation name cannot be empty"); - } - - // convert the parameter pack to a std::vector - Tensors tensors = {args...}; - - // get the operation and output type - auto [operation, output_type, shape_info] = GetOperation(op, tensors); - - // create argument list - NodeArguments arguments = NodeArguments(); - - AddArgument(arguments, memory, ArgType::Memory); - AddArguments(arguments, tensors, ArgType::Input); - AddArguments(arguments, shape, ArgType::Shape); - - return CreateNode(output_type, arguments, op); - } - - static bool CheckIndices(const Tensors& indices); - - template - static Tensor& MemoryOp(string op, const Tensor* memory, - const Tensors indices, const Args*... args) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Memory operation name cannot be empty"); - } - - if(!CheckIndices(indices)) { - throw std::runtime_error("Tensor indices must be integers"); - } - - //check if memory op is local - bool is_local = false; - if(memory->node_->op->HasAllTypes(OpProp::LocalMemory)) { - is_local = true; - } - - //if local, can only have 1 index - if(is_local && indices.size() > 1) { - throw std::runtime_error("Local memory operations can only have 1 index"); - } - - //if global, can only have up to memory dimension indices - size_t memory_dim = memory->GetDimension(); - if(!is_local && indices.size() > memory_dim) { - throw std::runtime_error("Too many indices for memory operation, memory has " + std::to_string(memory_dim) + " dimensions, while " + std::to_string(indices.size()) + " indices were provided"); - } - - // convert the parameter pack to a std::vector - Tensors tensors = {args...}; - - // get the operation and output type - auto [operation, output_type, shape_info] = GetOperation(op, tensors); - - if (operation->HasAllTypes(OpProp::Modifier)) - { - memory->node_->flags.set(NodeProp::Modified); - } - - // create argument list - NodeArguments arguments = NodeArguments(); - - AddArgument(arguments, memory, ArgType::Memory); - AddArguments(arguments, tensors, ArgType::Input); - AddArguments(arguments, indices, ArgType::Index); - - // get an input node that has shape arguments - NodeArguments shape_arguments = shape_info.GetArguments(); - - //use index shape instead if no input shape is found - if (shape_arguments.empty()) - { - for (const Tensor* index : indices) - { - shape_arguments = index->node_->args.GetArguments(ArgType::Shape); - if (!shape_arguments.empty()) { - break; - } - } - } - - //if no indices or inputs exist, use memory shape - if (indices.empty() && tensors.empty()) - { - shape_arguments = memory->node_->args.GetArguments(ArgType::Shape); - } - - AddArguments(arguments, shape_arguments); - - if (op == "load") output_type = memory->GetFormat(); - - Tensor& output = CreateNode(output_type, arguments, op); - if(is_local) output.node_->flags.set(NodeProp::LocalMemoryOp); - return output; - } - - static Tensor& Static(string op, const NodeArguments& shape, - const TFDataFormat format) { - op = RemoveSpaces(op); - - if (op.empty()) { - throw std::runtime_error("Static operation name cannot be empty"); - } - - const Operation* operation = FindOperation(op); - // check if output is valid - if (!operation->IsOutputValid(format)) { - throw std::runtime_error("Type " + DataTypeToString(format.type) + "(" + std::to_string(format.size) + ") is not valid for operation \"" + op + "\""); - } - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape); - return CreateNode(format, arguments, op); - } - - static Tensor& Static(const string& op, const Tensors& shape, - const TFDataFormat type) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - return Static(op, arguments, type); - } - - static Tensor& Static(const string& op, const TFDataFormat type) { - return Static(op, NodeArguments(), type); - } - - static void SetEvaluationContext(IR* ir) { - if (evaluation_context_ir_ != nullptr && ir != nullptr) { - throw std::runtime_error("Evaluation context change is forbidden."); - } - evaluation_context_ir_ = ir; - } - - static IR* GetEvaluationContext() { - return evaluation_context_ir_; - } - - string GetConstantString() const; - - static Tensor& CustomOperation(const string & name, Tensors inputs, Tensors shape) { - return OpShape(name, shape, inputs); - } - - Node* node_ = nullptr; - - TFDataFormat GetFormat() const; - void SetData(const vector& data) const; - void SetData(uint data) const; - void SetData(float data) const; - void SetData(int data) const; - void SetFormat(TFDataFormat type) const; - void DetachGrad() const; - void PassGrad() const; - void StopFusion() const; - void HintRange(float min, float max) const; - void HintRange(int min, int max) const; - void HintRange(uint min, uint max) const; - - static Tensor* GetCopy(const Tensor& other, NodeArguments args); - - static Tensor* GetCopy(const Tensor& other); - - void SetMemoryType(NodeProp memory_type, int index = 0) const { - node_->SetMemoryType(memory_type, index); - } - - int GetDimension() const { - ShapeInfo shape_info = ShapeInfo(node_); - return shape_info.dim; - } - - Tensors GetShape() const { - ShapeInfo shape_info = ShapeInfo(node_); - return shape_info.GetTensors(); - } - - Tensors GetReverseShape() const { - Tensors shape = GetShape(); - std::reverse(shape.begin(), shape.end()); - return shape; - } - - ShapeInfo GetShapeInfo() const { - return ShapeInfo(node_); - } - - void SetShape(Tensors shape) const; - - int TryGetConstant() const { - if (node_->name != "const") { - return -1; - } - return AsInt(node_->data[0]); - } - - // tensor factory methods - static Tensors GetConstantShape(const vector& shape) { - Tensors result = Tensors(); - for (int i : shape) { - result.push_back(&Constant(i)); - } - return result; - } - - static Tensor& Constant(float value) { - Tensor& output = Static("const", TFTypeFloat32); - output.SetData(AsUint(value)); - return output; - } - static Tensor& Constant(int value) { - Tensor& output = Static("const", TFTypeInt32); - output.SetData(AsUint(value)); - return output; - } - static Tensor& Constant(uint value) { - Tensor& output = Static("const", TFTypeUint32); - output.SetData(value); - return output; - } - static Tensor& Constant(bool value) { - Tensor& output = Static("const", TFTypeBool32); - output.SetData(value); - return output; - } - static Tensor& Constant(uint value, TFDataFormat type) { - Tensor& output = Static("const", type); - output.SetData(value); - return output; - } - static Tensor& Constant(uint value, const Tensors& shape, TFDataFormat type) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, type); - output.SetData(value); - return output; - } - - static Tensor& Constant(const Tensors& shape, float value) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, TFTypeFloat32); - output.SetData(value); - return output; - } - static Tensor& Constant(const vector& shape, float value) { - return Constant(GetConstantShape(shape), value); - } - static Tensor& Constant(const Tensors& shape, int value) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, TFTypeInt32); - output.SetData(value); - return output; - } - static Tensor& Constant(const vector& shape, int value) { - return Constant(GetConstantShape(shape), value); - } - - static Tensor& Constant(const Tensors& shape, uint value) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, TFTypeUint32); - output.SetData(value); - return output; - } - - static Tensor& Constant(const vector& shape, uint value) { - return Constant(GetConstantShape(shape), value); - } - static Tensor& Constant(const Tensors shape, uint value, TFDataFormat type) { - NodeArguments arguments = NodeArguments(); - AddArguments(arguments, shape, ArgType::Shape); - Tensor& output = Static("const", arguments, type); - output.SetData(value); - return output; - } - - static Tensors GetShapeTensors(const vector& shape) { - Tensors result = Tensors(); - for (int i : shape) { - result.push_back(&Constant(i)); - } - return result; - } - - static Tensor& Memory(const TFDataFormat type) { return Static("memory", type); } - static Tensor& Memory(const Tensors& shape, - const TFDataFormat type = TFTypeFloat32) { - return Static("memory", shape, type); - } - static Tensor& Memory(const NodeArguments& shape, - const TFDataFormat type = TFTypeFloat32) { - return Static("memory", shape, type); - } - static Tensor& Memory(const vector& shape, - const TFDataFormat type = TFTypeFloat32) { - return Memory(GetShapeTensors(shape), type); - } - - static Tensor& LocalMemory(const int size, const TFDataFormat type) { - Tensor& output = Static("local_memory", type); - output.SetData(size); - return output; - } - - static Tensor& GroupMemory(const int size, const TFDataFormat type) { - Tensor& output = Static("group_memory", type); - output.SetData(size); - return output; - } - - static void GroupBarrier() { - Op("group_barrier"); - } - - static Tensors GetInputShapeTensors(Tensors shape); - - static Tensor& Input(const TFDataFormat type = TFTypeFloat32) { - Tensor& output = Memory(type); - output.SetMemoryType(NodeProp::InputMemory); - return output; - } - static Tensor& Input(const Tensors& shape, - const TFDataFormat type = TFTypeFloat32) { - Tensor& output = Memory(GetInputShapeTensors(shape), type); - output.SetMemoryType(NodeProp::InputMemory); - return output; - } - static Tensor& Input(const vector& shape, - const TFDataFormat type = TFTypeFloat32) { - return Input(GetShapeTensors(shape), type); - } - - static Tensor& Index(NodeArguments shape, int dim) { - Tensor& output = Static("dim_id", shape, TFTypeInt32); - output.SetData(dim); - output.SetFormat(TFTypeInt32); - return output; - } - - static Tensor& Index(Tensors shape, int dim) { - Tensor& output = Static("dim_id", shape, TFTypeInt32); - output.SetData(dim); - output.SetFormat(TFTypeInt32); - return output; - } - - static Tensor& Index(const vector& shape, int dim) { - return Index(GetConstantShape(shape), dim); - } - - static Tensors Indices(Tensors shape) { - int dims = (int)shape.size(); - Tensors indices = Tensors(); - for (int i = 0; i < dims; i++) { - indices.push_back(&Index(shape, i)); - } - return indices; - } - - static Tensor& FlatIndex(Tensors shape, Tensors indices) { - int memory_dim = (int)shape.size(); - if(memory_dim == 0) return Constant(0); - // compute the flat index (C-order) - Tensor* flat_index = const_cast(indices[0]); - for (int i = 1; i < memory_dim; i++) { - flat_index = &(*flat_index * *shape[i]); - flat_index = &(*flat_index + *indices[i]); - } - return *flat_index; - } - - static Tensors IndicesFromFlatIndex(const Tensor* index, Tensors shape) - { - size_t dims = shape.size(); - Tensors indices = Tensors(dims); - Tensors sizes = Tensors(dims); - sizes[0] = shape[0]; - for (size_t i = 1; i < dims - 1; i++) { - sizes[i] = &(*sizes[i - 1] * *shape[i]); - } - - Tensor* temp; - for (size_t i = 0; i < dims; i++) { - Tensor* idx0 = const_cast(index); - if (i < dims - 1) { - idx0 = &(*idx0 / *sizes[dims - i - 2]); - } - if (i > 0) { - temp = &(*temp * *shape[dims - i - 1]); - idx0 = &(*idx0 - *temp); - if (i != dims - 1) temp = &(*temp + *idx0); - } else { - temp = idx0; - } - indices[dims - i - 1] = idx0; - } - - return indices; - } - - static Tensor& ElementIndex(Tensors shape) { - return FlatIndex(shape, Indices(shape)); - } - - static Tensor& GetSeed(Tensors shape, const Tensor& seed) { - Tensor* full_seed = &const_cast(seed); - if(full_seed->node_->format.type != TFType::Uint) { - full_seed = &asuint(*full_seed); //convert seed to uint - } - full_seed = &(touint(ElementIndex(shape)) + *full_seed * Constant(2654435761u)); - return *full_seed; - } - - static Tensor& Hash(Tensors shape, const Tensor& seed) { - return pcg(GetSeed(shape, seed)); - } - - static Tensor& Random(Tensors shape, const Tensor& seed) { - return pcgf(GetSeed(shape, seed)); - } - - Tensor& BlockIndex() const { - Tensor& output = Static( - "block_id", node_->args.GetArguments(ArgType::Shape), TFTypeInt32); - output.SetFormat(TFTypeInt32); - return output; - } - - Tensor& BlockThreadIndex(int i) const { - Tensor& output = Static( - "block_thread_id", node_->args.GetArguments(ArgType::Shape), TFTypeInt32); - output.SetFormat(TFTypeInt32); - output.SetData(i); - return output; - } - - static Tensor& Load(const Tensor& tensor, const Tensors& indices = Tensors(), - IndexingMode mode = IndexingMode::Clamp); - - static Tensor& Deallocate(const Tensor& tensor) { - return MemoryOp("deallocate", &tensor, {}); - } - - Tensor& Index(int dim) const { - Tensor& output = Static("dim_id", node_->args.GetArguments(ArgType::Shape), TFTypeInt32); - output.SetData(dim); - output.SetFormat(TFTypeInt32); - return output; - } - - Tensors Indices() const { - return Indices(GetShape()); - } - - static Tensor& Store(const Tensor& tensor, const Tensor& value, - const Tensors& indices = Tensors(), IndexingMode mode = IndexingMode::Clamp); - - void Set(const Tensor& value) const { - //check if memory and value shapes are compatible - ShapeCompareResult shape_result = CompareShape(node_, value.node_, true); - if (!shape_result.compatible) { - throw std::runtime_error("Cannot set tensor with incompatible shape"); - } - MemoryOp("set", this, {}, &value); - //update the shape of the tensor - SetShape(shape_result.broadcast_shape.GetTensors()); - } - - static void ScatterAdd(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - MemoryOp("InterlockedAdd", &tensor, indices, &value); - } - - static Tensor& ScatterAddPrev(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedAdd_Prev", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - return a; - } - - static void ScatterMax(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedMax", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static void ScatterMin(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedMin", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static void ScatterOr(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedOr", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static void ScatterAnd(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedAnd", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static void ScatterXor(const Tensor& tensor, const Tensor& value, - const Tensors& indices, IndexingMode mode = IndexingMode::Clamp) { - Tensor& a = MemoryOp("InterlockedXor", &tensor, indices, &value); - a.node_->indexing_mode_ = mode; - } - - static int GetAxis(int dims, int axis) { - if (axis < 0) { - axis = dims + axis; - } - return axis; - } - - static Tensor& ReductionOP(string name, const Tensor& tensor, int axis = 0, bool keepdims = false); - static Tensor& ScanOP(string name, const Tensor& tensor, int axis = 0); - - static Tensor& Sum(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_sum", tensor, axis); - } - - static Tensor& Norm(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_norm", tensor, axis); - } - - static Tensor& Mean(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_mean", tensor, axis); - } - - static Tensor& Max(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_max", tensor, axis); - } - - static Tensor& Any(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_any", tensor, axis); - } - - static Tensor& All(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_all", tensor, axis); - } - - static Tensor& Min(const Tensor& tensor, int axis = 0) { - return ReductionOP("dim_min", tensor, axis); - } - - static Tensor& PrefixSum(const Tensor& tensor, int axis = 0) { - return ScanOP("dim_prefix_sum", tensor, axis); - } - - static Tensor& Reverse(const Tensor& tensor, int axis = 0) { - Tensors shape = tensor.GetShape(); - int dims = (int)shape.size(); - axis = GetAxis(dims, axis); - Tensor& output = OpShape("dim_reverse", shape, &tensor); - output.SetData(axis); - return output; - } - - static Tensor& Repeat(const Tensor& tensor, const Tensor& repeats, int axis = 0) { - //check if repeats is a scalar - if (repeats.GetDimension() != 0) { - throw std::runtime_error("Repeats argument must be a scalar"); - } - int dims = (int)tensor.GetDimension(); - axis = GetAxis(dims, axis); - Tensors shape = tensor.GetShape(); - Tensors new_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - new_shape.push_back(&(*shape[i] * repeats)); - } else { - new_shape.push_back(shape[i]); - } - } - Tensor& output = OpShape("dim_repeat", new_shape, &tensor); - output.SetData(axis); - return output; - } - - static Tensor& SplitDim(const Tensor& tensor, int split_size = 128, int axis = 0) { - ShapeInfo shapeinfo = tensor.GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - axis = GetAxis(dims, axis); - Tensors new_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if (i == axis) { - new_shape.push_back(&Tensor::Constant(split_size)); - new_shape.push_back(&((*shape[i] + Tensor::Constant(split_size - 1)) / Tensor::Constant(split_size))); - } else { - new_shape.push_back(shape[i]); - } - } - Tensor& output = OpShape("dim_split", new_shape, &tensor); - output.SetData({(uint)axis, (uint)split_size}); - return output; - } - - static Tensor& MergeDim(const Tensor& tensor, int axis = 0, const Tensor* target_size = nullptr) { - ShapeInfo shapeinfo = tensor.GetShapeInfo(); - int dims = shapeinfo.dim; - Tensors shape = shapeinfo.GetTensors(); - axis = GetAxis(dims, axis); - if(axis == 0) axis = 1; - const Tensor* target_size_tensor = nullptr; - if(target_size == nullptr) { - target_size_tensor = &(*shape[axis] * *shape[axis+1]); - } else { - target_size_tensor = target_size; - } - axis = GetAxis(dims, axis); - Tensors new_shape = Tensors(); - for (int i = 0; i < dims; i++) { - if(i == axis) { - new_shape.push_back(target_size_tensor); - } else if(i != axis+1) { - new_shape.push_back(shape[i]); - } - } - Tensor& output = OpShape("dim_merge", new_shape, &tensor); - output.SetData(axis); - return output; - } - - static Tensor& Transpose(const Tensor& tensor, const int axis1 = 1, const int axis2 = 0) { - ShapeInfo shapeinfo = tensor.GetShapeInfo(); - - int dims = std::max(std::max(axis1+1, axis2+1), std::max(shapeinfo.dim, -std::min(axis1, axis2))); - int a1 = GetAxis(dims, axis1); - int a2 = GetAxis(dims, axis2); - shapeinfo.ExpandDimensionsTo(dims); - Tensors shape = shapeinfo.GetTensors(); - //swap the axes - std::swap(shape[a1], shape[a2]); - Tensor& output = OpShape("transpose", shape, &tensor); - //add data - output.SetData({AsUint(a1), AsUint(a2)}); - return output; - } - - //dot product of - static Tensor& Dot(const Tensor& tensor1, const Tensor& tensor2, int axis = 0) { - Tensors shape = tensor1.GetShape(); - int dims = (int)shape.size(); - axis = GetAxis(dims, axis); - shape.erase(shape.begin() + axis); - Tensor& output = OpShape("dot", shape, &tensor1, &tensor2); - output.SetData(axis); - return output; - } - - static Tensor& Unsqueeze(const Tensor& tensor, int axis = 0) { - Tensors shape = tensor.GetShape(); - int dims = (int)shape.size(); - if(axis < 0) { - axis = dims + axis + 1; - } - axis = std::max(0, std::min(dims, axis)); - shape.insert(shape.begin() + axis, &Constant(1)); - Tensor& output = OpShape("unsqueeze", shape, &tensor); - output.SetData(axis); - return output; - } - - static bool AreTensorsEqual(const Tensor& a, const Tensor& b); - - static Tensor& Squeeze(const Tensor& tensor, int axis = 0) { - Tensors shape = tensor.GetShape(); - int dims = (int)shape.size(); - axis = GetAxis(dims, axis); - if (shape[axis]->TryGetConstant() != 1) { - throw std::runtime_error("Cannot squeeze a dimension that is not 1"); - } - shape.erase(shape.begin() + axis); - Tensor& output = OpShape("squeeze", shape, &tensor); - output.SetData(axis); - return output; - } - - //takes two tensors [T1, T2, ..., Tn, M, N] and [Tm, .., Tn, N, K] and returns [T1, T2, ..., Tm, M, K] - static Tensor& Matmul(const Tensor& a, const Tensor& b) { - ShapeInfo shape_a = a.GetShapeInfo(); - ShapeInfo shape_b = b.GetShapeInfo(); - - if (shape_a.dim < 2 && shape_b.dim < 2) { - throw std::runtime_error("Matrix multiplication requires at least one 2D tensor"); - } - - if(shape_a.dim < 2) { - shape_a.ExpandDimensionsTo(2); - } - if(shape_b.dim < 2) { - shape_b.ExpandDimensionsTo(2); - } - - Tensors shape_a_tensors = shape_a.GetTensors(); - Tensors shape_b_tensors = shape_b.GetTensors(); - - //get shape of the result - Tensors shape_c = Tensors(); - int dim_a = shape_a.dim; - int dim_b = shape_b.dim; - int max_dim = 0; - Tensors max_shape = Tensors(); - //get the shape with most dimensions - if (dim_a < dim_b) { - max_dim = dim_b; - max_shape = shape_b_tensors; - } else { - max_dim = dim_a; - max_shape = shape_a_tensors; - } - - shape_c.push_back(shape_b_tensors[0]); - shape_c.push_back(shape_a_tensors[1]); - for (int i = 2; i < max_dim; i++) { - shape_c.push_back(max_shape[i]); - } - ShapeDimCompareResult result = CompareShapeDim(shape_a_tensors[0]->node_, shape_b_tensors[1]->node_); - if (!result.compatible) { - throw std::runtime_error("Inner dimensions of the matrices must match"); - } - - Tensor& output = OpShape("matmul", shape_c, &a, &b); - return output; - } - - static Tensor& Reshape(const Tensor &tensor, const Tensors &shape, TFDataFormat format = TFTypeNone); - static Tensor& Assert(const Tensor& tensor, const Tensors& shape, TFDataFormat type = TFTypeFloat32); - - Tensors enter_tensors = Tensors(); - bool already_entered = false; - - std::variant Enter() - { - if(!node_->op->HasAllTypes(OpProp::HasChildren)) { - throw std::runtime_error("The node of type " + node_->name + " cannot have children"); - } - - if(already_entered) { - throw std::runtime_error("Already entered node scope"); - } - evaluation_context_ir_->BeginScopeLastChild(node_); //begin at the last child - already_entered = true; - if(enter_tensors.size() > 0) { - return enter_tensors; //if we have some special info, like indices of kernel threads - } else { - return const_cast(this); - } - } - - void Exit() - { - evaluation_context_ir_->EndScope(); - } - - static Tensor& Loop(const Tensor& start, const Tensor& end, const Tensor& step) - { - return Op("loop", &start, &end, &step); - } - - static void Loop(const Tensor& start, const Tensor& end, const Tensor& step, - const function& body) { - // create the loop - Tensor& loop = Loop(start, end, step); - - evaluation_context_ir_->ExecuteExpressionFirstChild(loop.node_, [&]() { - // create the body - body(loop); - }); - } - - static Tensor& If(const Tensor& condition) { - // create the if - Tensor& if_tensor = Op("if", &condition); - return if_tensor; - } - - static void If(const Tensor& condition, - const std::function& body) { - // create the if - Tensor& if_tensor = If(condition); - - evaluation_context_ir_->ExecuteExpressionFirstChild(if_tensor.node_, [&]() { - // create the body - body(); - }); - } - - static void If(const Tensor& condition, const std::function& true_body, - const std::function& false_body) { - If(condition, true_body); - If(!condition, false_body); - } - - static void Vmap(const Tensors inputs, const Tensors shape, const std::function& body) { - // create the if - Tensor& vmap_main = OpShape("vmap", Tensors(), inputs); - - evaluation_context_ir_->ExecuteExpressionFirstChild(vmap_main.node_, [&]() { - // create the body - body(); - }); - } - - static Tensor& Kernel(const Tensors shape, vector group_size = {}) { - // create the kernel - Tensor& kernel = Static("kernel", shape, TFTypeNone); - evaluation_context_ir_->ExecuteExpressionFirstChild(kernel.node_, [&]() { - for (int i = 0; i < shape.size(); i++) { - kernel.enter_tensors.push_back(&Index(shape, i)); //thread indices - } - }); - kernel.node_->group_size = group_size; - return kernel; - } - - static Tensor& Kernel(const Tensors shape, const std::function& body, vector group_size = {}) { - // create the kernel - Tensor& kernel = Kernel(shape); - - evaluation_context_ir_->ExecuteExpressionLastChild(kernel.node_, [&]() { - // create the body - body(kernel.enter_tensors); - }); - - kernel.node_->group_size = group_size; - return kernel; - } - - static Tensor& Kernel(const NodeArguments& shape) - { - // create the kernel - Tensor& kernel = Static("kernel", shape, TFTypeNone); - return kernel; - } - - static void Break() { - // create the break - Tensor& break_tensor = Static("break", TFTypeNone); - } - - static void Continue() { - // create the continue - Tensor& continue_tensor = Static("continue", TFTypeNone); - } - - // destructor - ~Tensor() = default; - - Tensor& operator-() const { return Op("neg", this); } - Tensor& operator!() const { - if(node_->format.type == TFType::Bool) { - return Op("notb", this); - } else { - return Op("not", this); - } - } - Tensor& operator~() const { - if(node_->format.type == TFType::Bool) { - return Op("notb", this); - } else { - return Op("not", this); - } - } - - Tensor& operator+(const Tensor& other) const { - return Op("add", this, &other); - } - - Tensor& operator-(const Tensor& other) const { - return Op("sub", this, &other); - } - - Tensor& operator*(const Tensor& other) const { - return Op("mul", this, &other); - } - - Tensor& operator/(const Tensor& other) const { - return Op("div", this, &other); - } - - Tensor& operator%(const Tensor& other) const { - return Op("mod", this, &other); - } - - Tensor& operator>(const Tensor& other) const { - return Op("gt", this, &other); - } - - Tensor& operator<(const Tensor& other) const { - return Op("lt", this, &other); - } - - Tensor& operator>=(const Tensor& other) const { - return Op("gte", this, &other); - } - - Tensor& operator<=(const Tensor& other) const { - return Op("lte", this, &other); - } - - Tensor& operator==(const Tensor& other) const { - return Op("eq", this, &other); - } - - Tensor& operator!=(const Tensor& other) const { - return Op("neq", this, &other); - } - - Tensor& operator&&(const Tensor& other) const { - return Op("and", this, &other); - } - - Tensor& operator||(const Tensor& other) const { - return Op("or", this, &other); - } - - Tensor& operator&(const Tensor& other) const { - return Op("and", this, &other); - } - - Tensor& operator|(const Tensor& other) const { - return Op("or", this, &other); - } - - Tensor& operator^(const Tensor& other) const { - return Op("xor", this, &other); - } - - Tensor& operator<<(const Tensor& other) const { - return Op("lshift", this, &other); - } - - Tensor& operator>>(const Tensor& other) const { - return Op("rshift", this, &other); - } - - void operator=(const Tensor& other) = delete; - - static Tensor& copy(const Tensor& tensor) { - return Op("copy", &tensor); - } - - static Tensor& sin(const Tensor& x) { return Op("sin", &x); } - static Tensor& cos(const Tensor& x) { return Op("cos", &x); } - static Tensor& tan(const Tensor& x) { return Op("tan", &x); } - static Tensor& asin(const Tensor& x) { return Op("asin", &x); } - static Tensor& acos(const Tensor& x) { return Op("acos", &x); } - static Tensor& atan(const Tensor& x) { return Op("atan", &x); } - static Tensor& sinh(const Tensor& x) { return Op("sinh", &x); } - static Tensor& cosh(const Tensor& x) { return Op("cosh", &x); } - static Tensor& tanh(const Tensor& x) { return Op("tanh", &x); } - static Tensor& asinh(const Tensor& x) { return Op("asinh", &x); } - static Tensor& acosh(const Tensor& x) { return Op("acosh", &x); } - static Tensor& atanh(const Tensor& x) { return Op("atanh", &x); } - static Tensor& exp(const Tensor& x) { return Op("exp", &x); } - static Tensor& log(const Tensor& x) { return Op("log", &x); } - static Tensor& log2(const Tensor& x) { return Op("log2", &x); } - static Tensor& exp2(const Tensor& x) { return Op("exp2", &x); } - static Tensor& sqrt(const Tensor& x) { return Op("sqrt", &x); } - static Tensor& sqr(const Tensor& x) { return Op("sqr", &x); } - static Tensor& rsqrt(const Tensor& x) { return Op("rsqrt", &x); } - static Tensor& rcp(const Tensor& x) { return Op("rcp", &x); } - static Tensor& abs(const Tensor& x) { return Op("abs", &x); } - static Tensor& sign(const Tensor& x) { return Op("sign", &x); } - static Tensor& floor(const Tensor& x) { return Op("floor", &x); } - static Tensor& ceil(const Tensor& x) { return Op("ceil", &x); } - static Tensor& round(const Tensor& x) { return Op("round", &x); } - static Tensor& trunc(const Tensor& x) { return Op("trunc", &x); } - static Tensor& frac(const Tensor& x) { return Op("frac", &x); } - - static Tensor& pcg(const Tensor& x) { return Op("pcg", &x); } - static Tensor& pcgf(const Tensor& x) { return Op("pcgf", &x); } - - static Tensor& reversebits(const Tensor& x) { return Op("reversebits", &x); } - - static Tensor& tofloat(const Tensor& x) { return Op("float", &x); } - static Tensor& toint(const Tensor& x) { return Op("int", &x); } - static Tensor& touint(const Tensor& x) { return Op("uint", &x); } - static Tensor& tobool(const Tensor& x) { return Op("bool", &x); } - - static Tensor& asfloat(const Tensor& x) { return Op("asfloat", &x); } - static Tensor& asint(const Tensor& x) { return Op("asint", &x); } - static Tensor& asuint(const Tensor& x) { return Op("asuint", &x); } - - static Tensor& clamp(const Tensor& x, const Tensor& min, const Tensor& max) { - return Op("clamp", &x, &min, &max); - } - - static Tensor& pow(const Tensor& x, const Tensor& y) { - return Op("pow", &x, &y); - } - - static Tensor& min(const Tensor& x, const Tensor& y) { - return Op("min", &x, &y); - } - - static Tensor& max(const Tensor& x, const Tensor& y) { - return Op("max", &x, &y); - } - - static Tensor& mod(const Tensor& x, const Tensor& y) { - return Op("mod", &x, &y); - } - - static Tensor& modf(const Tensor& x, const Tensor& y) { - return Op("modf", &x, &y); - } - - static Tensor& atan2(const Tensor& x, const Tensor& y) { - return Op("atan2", &x, &y); - } - - static Tensor& grad(const Tensor& x, const Tensor& wrt) { - if(x.node_->op->HasAllTypes(OpProp::Nondiff) && !x.node_->flags.has(NodeProp::Modified)) { - throw std::runtime_error("Cannot compute gradient of a non-differentiable operation"); - } - return OpShape("backwards_grad", wrt.GetShape(), &x, &wrt); - } - - static Tensor& lerp(const Tensor& x, const Tensor& y, const Tensor& a) { - return Op("lerp", &x, &y, &a); - } - - static Tensor& smoothstep(const Tensor& a, const Tensor& b, const Tensor& x) { - return Op("smoothstep", &a, &b, &x); - } - - static Tensor& select(const Tensor& cond, const Tensor& x, const Tensor& y) { - return Op("ternary", &cond, &x, &y); - } - - static Tensor& fma(const Tensor& x, const Tensor& y, const Tensor& z) { - return Op("fma", &x, &y, &z); - } - - static Tensors IndexGrid(const Tensors& begin, const Tensors& end) { - //compute shape - Tensors shape = Tensors(); - for (int i = 0; i < begin.size(); i++) { - shape.push_back(&(*end[i] - *begin[i])); - } - //compute indices - Tensors index_grid = Tensors(); - for (int i = 0; i < begin.size(); i++) { - index_grid.push_back(&(Index(shape, i) + *begin[i])); - } - return index_grid; - } - - static Tensors IndexGrid(const Tensors& begin, const Tensors& end, const Tensors& step) - { - Tensors shape = Tensors(); - for (int i = 0; i < begin.size(); i++) { - shape.push_back(&((*end[i] - *begin[i]) / *step[i])); - } - //compute indices - Tensors index_grid = Tensors(); - for (int i = 0; i < begin.size(); i++) { - index_grid.push_back(&(Index(shape, i) * *step[i] + *begin[i])); - } - return index_grid; - } - - static void PrintValue(std::string name, const Tensor& tensor) { - if (tensor.GetDimension() != 0) { - throw std::runtime_error("Cannot print a non-scalar value"); - } - Tensor& output = Op("print_value", &tensor); - output.SetDebugName(name); - } - - static void AssertValue(std::string name, const Tensor& tensor) { - if (tensor.GetDimension() != 0) { - throw std::runtime_error("Cannot assert a non-scalar value"); - } - Tensor& output = Op("assert_value", &tensor); - output.SetDebugName(name); - } - - void SetDebugName(const string& name) const; - - static void BeginRegion(const string& name); - static void EndRegion(const string& name); - - int axis(int i = 0) const { - return (int)node_->data[i]; - } - - uint data(int i = 0) const { - return node_->data[i]; - } -}; - -} // namespace TensorFrost diff --git a/TensorFrost/Tensor/TensorProgram.cpp b/TensorFrost/Tensor/TensorProgram.cpp deleted file mode 100644 index 8d29de7e..00000000 --- a/TensorFrost/Tensor/TensorProgram.cpp +++ /dev/null @@ -1,89 +0,0 @@ -#include -#include -#include "TensorProgram.h" - -namespace TensorFrost { - -void TensorProgram::CreateProgram(string name) { - Tensor::SetEvaluationContext(nullptr); - - //get current time - auto start = std::chrono::high_resolution_clock::now(); - - // create new IR graph - Tensor::SetEvaluationContext(&ir); - Tensor::BeginRegion(name); - Tensors outputs = evaluate_callback(); - Tensor::EndRegion(name); - // set outputs - for (int i = 0; i < outputs.size(); i++) { - outputs[i]->SetMemoryType(NodeProp::OutputMemory, i); - } - ir.output_memory_count = (int)outputs.size(); - - if (outputs.size() == 0) { - throw std::runtime_error("TensorProgram does not do any computation: no outputs"); - } - - program = GenerateProgram(&ir); - program->program_name = name; - - Tensor::SetEvaluationContext(nullptr); - - auto end = std::chrono::high_resolution_clock::now(); - - compile_time = std::chrono::duration(end - start).count(); - - start = std::chrono::high_resolution_clock::now(); - - GenerateCode(program); - - end = std::chrono::high_resolution_clock::now(); - - codegen_time = std::chrono::duration(end - start).count(); - - if (current_backend != BackendType::CodeGen) // no need to compile if we are in codegen mode - { - auto start_time = chrono::high_resolution_clock::now(); - CompileAndLoadKernelModule(program, program_id); - auto end_time = chrono::high_resolution_clock::now(); - host_compile_time = chrono::duration(end_time - start_time).count(); - - start_time = chrono::high_resolution_clock::now(); - CompileKernels(program); - end_time = chrono::high_resolution_clock::now(); - shader_compile_time = chrono::duration(end_time - start_time).count(); - } -} - -vector TensorProgram::Evaluate( - const vector& input) const { - return ExecuteProgram(program, input); -} - -string TensorProgram::PrintProperties() const { - string properties = program->program_name + ":\n"; - int compute_kernels = (int)program->kernels_.size(); - int lines = 0; - string line; - istringstream stream(program->generated_code_); - while (getline(stream, line)) { - lines++; - } - properties += " Kernel count: " + to_string(compute_kernels) + "\n"; - properties += " Intermediate buffers: " + to_string(ir.temp_memory_count) + "\n"; - properties += " Host readbacks: " + to_string(ir.readbacks) + "\n"; - properties += " Host writes: " + to_string(ir.writebacks) + "\n"; - properties += " Lines of generated code: " + to_string(lines) + "\n"; - properties += " IR Compile time: " + to_string(compile_time) + " ms\n"; - properties += " Codegen time: " + to_string(codegen_time) + " ms\n"; - if(host_compile_time > 0.01f) - properties += " Host Compile time: " + to_string(host_compile_time) + " ms\n"; - if (shader_compile_time > 0.01f) - properties += " Shader Compile time: " + to_string(shader_compile_time) + " ms\n"; - return properties; -} - -size_t TensorProgram::program_id = 0; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Tensor/TensorProgram.h b/TensorFrost/Tensor/TensorProgram.h deleted file mode 100644 index 32abf0d1..00000000 --- a/TensorFrost/Tensor/TensorProgram.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "Backend/Backend.h" -#include "Compiler/KernelGen.h" -#include "Tensor.h" - -namespace TensorFrost { - -using namespace std; - -class TensorProgram { - public: - static size_t program_id; - using EvaluateFunction = function; - EvaluateFunction evaluate_callback; - IR ir; - Program* program; - bool debug = false; - float compile_time = 0.0f; - float codegen_time = 0.0f; - float host_compile_time = 0.0f; - float shader_compile_time = 0.0f; - - explicit TensorProgram(EvaluateFunction evaluate, string name) : evaluate_callback(std::move(evaluate)) { - CreateProgram(name); - program_id++; - } - - void CreateProgram(string name); - - vector Evaluate( - const vector& input) const; - - string PrintProperties() const; - - ~TensorProgram() { delete program; } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/TensorFrost.h b/TensorFrost/TensorFrost.h deleted file mode 100644 index 67d1accb..00000000 --- a/TensorFrost/TensorFrost.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -#ifdef _RELWITHDEBINFO -#define PYBIND11_DETAILED_ERROR_MESSAGES -#endif - -#include "Backend/Backend.h" -#include "Compiler/KernelGen.h" -#include "Tensor/Tensor.h" -#include "Tensor/TensorProgram.h" diff --git a/TensorFrost/Utility/Utility.cpp b/TensorFrost/Utility/Utility.cpp deleted file mode 100644 index 461ee402..00000000 --- a/TensorFrost/Utility/Utility.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Utility.h" - -namespace TensorFrost { - -int GetSize(const vector& shape) { - int size = 1; - for (int i : shape) { - size *= i; - } - return size; -} - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/Utility/Utility.h b/TensorFrost/Utility/Utility.h deleted file mode 100644 index 72d678d4..00000000 --- a/TensorFrost/Utility/Utility.h +++ /dev/null @@ -1,132 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace TensorFrost { -using namespace std; -using uint = unsigned int; - -inline uint AsUint(float f) { return *reinterpret_cast(&f); } -inline uint AsUint(int i) { return *reinterpret_cast(&i); } -inline float AsFloat(uint i) { return *reinterpret_cast(&i); } -inline float AsFloat(int i) { return *reinterpret_cast(&i); } -inline int AsInt(float f) { return *reinterpret_cast(&f); } -inline int AsInt(uint i) { return *reinterpret_cast(&i); } - -int GetSize(const vector& shape); - -template -class FlagSet { - array data; - -public: - FlagSet() { - data.fill(-1); - } - - void set(T flag) { - data[(int)flag] = 0; - } - - template - void set(T flag, Args... args) { - set(flag); - set(args...); - } - - void set(T flag, bool value) { - data[(int)flag] = value ? 0 : -1; - } - - void set(T flag, int64_t value) { - if(value < 0) { - throw std::runtime_error("Flag data must be non-negative"); - } - data[(int)flag] = value + 1; - } - - void remove(T flag) { - data[(int)flag] = -1; - } - - template - void remove(T flag, Args... args) { - remove(flag); - remove(args...); - } - - void clear() { - data.fill(-1); - } - - bool has(T flag) const { - return data[(int)flag] != -1; - } - - template - bool has(T flag, Args... args) const { - return has(flag) && has(args...); - } - - int64_t get(T flag, bool throw_error = true) const { - int64_t res = data[(int)flag] - 1; - if (throw_error && res < 0) { - throw std::runtime_error("Flag data is not set"); - } - return res; - } - - void copy_all_given(const FlagSet& other, unordered_set only) { - for (int i = 0; i < N; i++) { - data[i] = -1; - T flag = (T)i; - if (only.contains(flag)) { - data[(int)flag] = other.data[(int)flag]; - } - } - } - - void copy_all_except(const FlagSet& other, unordered_set except) { - for (int i = 0; i < N; i++) { - data[i] = -1; - T flag = (T)i; - if (!except.contains(flag)) { - data[(int)flag] = other.data[(int)flag]; - } - } - } - - void copy_all(const FlagSet& other) { - for (int i = 0; i < N; i++) { - data[i] = other.data[i]; - } - } - - size_t count() const { - size_t res = 0; - for (int i = 0; i < N; i++) { - if (data[i] != -1) { - res++; - } - } - return res; - } - - unordered_map get_data() const { - unordered_map res; - for (int i = 0; i < N; i++) { - T flag = (T)i; - if (has(flag)) { - res[flag] = data[i] - 1; - } - } - return res; - } -}; - -} // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/include/Definitions/VulkanBindings.h b/TensorFrost/include/Definitions/VulkanBindings.h new file mode 100644 index 00000000..14159ddc --- /dev/null +++ b/TensorFrost/include/Definitions/VulkanBindings.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace TensorFrost { + +void VulkanDefinitions(pybind11::module_& m); + +} // namespace TensorFrost diff --git a/TensorFrost/include/TensorFrost.h b/TensorFrost/include/TensorFrost.h new file mode 100644 index 00000000..25e5c84b --- /dev/null +++ b/TensorFrost/include/TensorFrost.h @@ -0,0 +1,12 @@ +#pragma once + +#include "Compiler/Operation.h" +#include "Compiler/ExecutionContext.h" +#include "Compiler/OperationBlocks.h" +#include "Compiler/OperationArguments.h" +#include "Compiler/Overloads.h" +#include "Compiler/Value.h" +#include "Compiler/Printer.h" +#include "Compiler/TFProgram.h" +#include "Backend/Vulkan.h" +#include "Backend/Window.h" \ No newline at end of file diff --git a/TensorFrost/src/Definitions/PyModule.cpp b/TensorFrost/src/Definitions/PyModule.cpp new file mode 100644 index 00000000..d5100867 --- /dev/null +++ b/TensorFrost/src/Definitions/PyModule.cpp @@ -0,0 +1,62 @@ +// #include +// #include +// +// #include +// #include +// +// namespace TensorFrost { +// +// class PyModule : public Module { +// public: +// using Module::Module; // Inherit constructors +// +// void assert_parameters() override { +// PYBIND11_OVERRIDE(void, Module, assert_parameters); +// } +// +// py::object loss(py::object X, py::object Y) override { +// PYBIND11_OVERRIDE_PURE(py::object, Module, loss, X, Y); +// } +// +// py::object forward(py::object X) override { +// PYBIND11_OVERRIDE_PURE(py::object, Module, forward, X); +// } +// }; +// +// void ModuleDefinitions(py::module& m) { +// py::class_(m, "Parameter") +// .def(py::init&, TFDataFormat, float, float, bool>(), py::arg("shape"), py::arg("dtype") = TFType::Float, py::arg("random_scale") = -1.0f, py::arg("random_offset") = 0.0f, py::arg("optimize") = true) +// .def_readwrite("shape", &Parameter::shape) +// .def_readwrite("dtype", &Parameter::dtype) +// .def_readwrite("random_scale", &Parameter::random_scale) +// .def_readwrite("random_offset", &Parameter::random_offset) +// .def("__repr__", [](const Parameter& p) { +// return "Parameter(shape=" + std::to_string(p.shape.size()) + ", dtype=" + std::to_string(p.dtype.type) + "( " + std::to_string(p.dtype.size) + ") , random_scale=" + std::to_string(p.random_scale) + ", random_offset=" + std::to_string(p.random_offset) + ", optimize=" + std::to_string(p.optimize) + ")"; +// }); +// +// py::class_(m, "ParameterArray") +// .def(py::init<>()) +// .def("__getitem__", &ParameterArray::getitem) +// .def("__setitem__", &ParameterArray::setitem) +// .def("items", &ParameterArray::items); +// +// py::class_(m, "Module") +// .def(py::init(), py::arg("requires_grad") = true) +// .def("__getattr__", &Module::getattr) +// .def("__setattr__", &Module::setattr) +// .def("hasattr", &Module::hasattr) +// .def("param_requires_grad", &Module::param_requires_grad) +// .def("initialize_input", &Module::initialize_input) +// .def("initialize_parameters", &Module::initialize_parameters) +// .def("initialize_parameters_native", &Module::initialize_parameters_native) +// .def("parameters", &Module::parameters) +// .def("named_parameters", &Module::named_parameters) +// .def("requires_grads_list", &Module::requires_grads_list) +// .def("create_input", &Module::create_input) +// .def("update_parameters", &Module::update_parameters) +// .def("assert_parameters", &Module::assert_parameters) +// .def("loss", &Module::loss) +// .def("forward", &Module::forward); +// } +// +// } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/src/Definitions/PyTensor.cpp b/TensorFrost/src/Definitions/PyTensor.cpp new file mode 100644 index 00000000..624e54e4 --- /dev/null +++ b/TensorFrost/src/Definitions/PyTensor.cpp @@ -0,0 +1,199 @@ +// #include +// #include +// +// #include +// +// namespace TensorFrost { +// +// void DefineOperator( +// const std::string& pyname, py::class_& py_tensor, +// const std::function& op) { +// py_tensor.def(l_op(pyname).c_str(), +// [op](const PyTensor& t, const PyTensor& t2) { +// return PT(op(T(t), T(t2))); +// }); +// py_tensor.def(l_op(pyname).c_str(), [op](const PyTensor& t, const float f) { +// return PT(op(T(t), Tensor::Constant(f))); +// }); +// py_tensor.def(l_op(pyname).c_str(), [op](const PyTensor& t, const int i) { +// return PT(op(T(t), Tensor::Constant(i))); +// }); +// py_tensor.def(r_op(pyname).c_str(), [op](const PyTensor& t, const float f) { +// return PT(op(Tensor::Constant(f), T(t))); +// }); +// py_tensor.def(r_op(pyname).c_str(), [op](const PyTensor& t, const int i) { +// return PT(op(Tensor::Constant(i), T(t))); +// }); +// } +// +// #define LAMBDA_OP(op) \ +// [](const Tensor& t1, const Tensor& t2) -> Tensor& { return t1 op t2; } +// +// void DefineOperators(py::class_& py_tensor) { +// DefineOperator("add", py_tensor, LAMBDA_OP(+)); +// DefineOperator("sub", py_tensor, LAMBDA_OP(-)); +// DefineOperator("mul", py_tensor, LAMBDA_OP(*)); +// DefineOperator("div", py_tensor, LAMBDA_OP(/)); +// DefineOperator("truediv", py_tensor, LAMBDA_OP(/)); +// DefineOperator("mod", py_tensor, LAMBDA_OP(%)); +// DefineOperator("eq", py_tensor, LAMBDA_OP(==)); +// DefineOperator("ne", py_tensor, LAMBDA_OP(!=)); +// DefineOperator("lt", py_tensor, LAMBDA_OP(<)); +// DefineOperator("le", py_tensor, LAMBDA_OP(<=)); +// DefineOperator("gt", py_tensor, LAMBDA_OP(>)); +// DefineOperator("ge", py_tensor, LAMBDA_OP(>=)); +// DefineOperator("and", py_tensor, LAMBDA_OP(&&)); +// DefineOperator("or", py_tensor, LAMBDA_OP(||)); +// DefineOperator("xor", py_tensor, LAMBDA_OP(^)); +// DefineOperator("lshift", py_tensor, LAMBDA_OP(<<)); +// DefineOperator("rshift", py_tensor, LAMBDA_OP(>>)); +// DefineOperator("and_", py_tensor, LAMBDA_OP(&)); +// DefineOperator("or_", py_tensor, LAMBDA_OP(|)); +// +// py_tensor.def("__neg__", [](const PyTensor& t) { return PT(-T(t)); }); +// py_tensor.def("__not__", [](const PyTensor& t) { return PT(!T(t)); }); +// py_tensor.def("__invert__", [](const PyTensor& t) { return PT(~T(t)); }); +// py_tensor.def("__pow__", [](const PyTensor& t, const PyTensor& t2) { +// return PT(Tensor::pow(T(t), T(t2))); +// }); +// py_tensor.def("__pow__", [](const PyTensor& t, float f) { +// return PT(Tensor::pow(T(t), Tensor::Constant(f))); +// }); +// py_tensor.def("__rpow__", [](const PyTensor& t, float f) { +// return PT(Tensor::pow(Tensor::Constant(f), T(t))); +// }); +// py_tensor.def("__matmul__", [](const PyTensor& t, const PyTensor& t2) { +// return PT(Tensor::Matmul(T(t), T(t2))); +// }); +// } +// +// void PyTensorDefinition(py::module& /*m*/, py::class_& py_tensor) { +// // initializers +// py_tensor.def(py::init()); +// py_tensor.def(py::init()); +// py_tensor.def(py::init()); +// py_tensor.def(py::init()); +// +// // properties +// py_tensor.def_property_readonly("shape", [](const PyTensor& t) { +// return PyTensorsFromTensors(Reverse(t.Get().GetShape())); +// }); +// py_tensor.def_property_readonly( +// "type", [](const PyTensor& t) { return t.Get().GetFormat(); }); +// py_tensor.def_property_readonly("indices", [](const PyTensor& t) { +// int dim = T(t).GetDimension(); +// py::tuple indices(dim); +// for (int i = 0; i < dim; i++) { +// indices[i] = PT(T(t).Index(dim - i - 1)); +// } +// return indices; +// }); +// py_tensor.def_property_readonly("op_name", [](const PyTensor& t) { +// return T(t).node_->name; +// }); +// py_tensor.def("try_get_constant", [](const PyTensor& t) { +// if(T(t).node_->name != "const") { +// throw std::runtime_error("Can not get constant from non-constant tensor"); +// } +// return T(t).TryGetConstant(); +// }); +// py_tensor.def("index",[](const PyTensor& t, int dim) { +// int dims = T(t).GetDimension(); +// return PT(T(t).Index(dims - dim - 1)); +// }); +// +// py_tensor.def("block_index", [](const PyTensor& t) { +// return PT(T(t).BlockIndex()); +// }); +// +// py_tensor.def("block_thread_index", [](const PyTensor& t, int block_dim) { +// return PT(T(t).BlockThreadIndex(block_dim)); +// }); +// +// py_tensor.def("detach_grad", [](const PyTensor& t) { +// t.Get().DetachGrad(); +// return t; +// }); +// +// py_tensor.def("pass_grad", [](const PyTensor& t) { +// t.Get().PassGrad(); +// return t; +// }); +// +// py_tensor.def("stop_fusion", [](const PyTensor& t) { +// t.Get().StopFusion(); +// return t; +// }); +// +// py_tensor.def("hint_range", [](const PyTensor& t, py::object min, py::object max) { +// if(t.Get().node_->format == TFTypeFloat32) { +// t.Get().HintRange(py::cast(min), py::cast(max)); +// } else { +// t.Get().HintRange(py::cast(min), py::cast(max)); +// } +// }, py::arg("min"), py::arg("max")); +// +// // operators +// DefineOperators(py_tensor); +// +// //no way to overload normal setter +// //TODO use python AST to generate these functions +// py_tensor.def("set", +// [](const PyTensor& t, const PyTensor& t2) { T(t).Set(T(t2)); }); +// +// py_tensor.def_property("val", [](const PyTensor& t) { return t; }, +// [](PyTensor& t, const PyTensor& val) { T(t).Set(T(val)); }); +// +// // indexing +// py_tensor.def("__getitem__", [](const PyTensor& t, const PyTensor& t1) { +// Tensors indices; +// indices.push_back(&t1.Get()); +// return PyTensor(&t.Get(), indices); +// }); +// py_tensor.def("__getitem__", [](const PyTensor& t, py::tuple indices_tuple) { +// Tensors indices = Reverse(TensorsFromTuple(indices_tuple)); +// return PyTensor(&t.Get(), indices); +// }); +// +// py_tensor.def("__setitem__", +// [](const PyTensor& t, const PyTensor& t1, const PyTensor& t2) { +// Tensors indices; +// indices.push_back(&t1.Get()); +// Tensor::Store(t.Get(), T(t2), indices); +// }); +// py_tensor.def("__setitem__", [](const PyTensor& t, py::tuple indices_tuple, +// const PyTensor& t2) { +// Tensors indices = Reverse(TensorsFromTuple(indices_tuple)); +// Tensor::Store(t.Get(), T(t2), indices); +// }); +// +// py_tensor.def("__setitem__", [](const PyTensor& t, const PyTensor& t1, pybind11::none none) { +// //do nothing +// }); +// py_tensor.def("__setitem__", [](const PyTensor& t, py::tuple indices_tuple, pybind11::none none) { +// //do nothing +// }); +// +// // transpose +// py_tensor.def("transpose", [](const PyTensor& t, int dim1, int dim2) { +// return PT(Tensor::Transpose(T(t), -dim1-1, -dim2-1)); +// }, py::arg("dim1") = -2, py::arg("dim2") = -1, "Transpose the tensor"); +// +// //transpose property +// py_tensor.def_property_readonly("T", [](const PyTensor& t) { +// return PT(Tensor::Transpose(T(t))); +// }); +// +// py_tensor.def("__str__", [](const PyTensor& t) { +// return GetNodeString(t.Get().node_); +// }); +// py_tensor.def("__repr__", [](const PyTensor& t) { +// return GetNodeString(t.Get().node_); +// }); +// +// py_tensor.def("set_debug_name", [](const PyTensor& t, const std::string& name) { +// t.Get().SetDebugName(name); +// }); +// } +// +// } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/src/Definitions/TensorFunctions.cpp b/TensorFrost/src/Definitions/TensorFunctions.cpp new file mode 100644 index 00000000..6f559ed1 --- /dev/null +++ b/TensorFrost/src/Definitions/TensorFunctions.cpp @@ -0,0 +1,353 @@ +// #include +// #include +// +// #include +// +// namespace TensorFrost { +// +// #define UNARY_FUNCTION(name) \ +// m.def(#name, [](const PyTensor& t) { return PT(Tensor::name(T(t))); }) +// +// #define BINARY_FUNCTION(name) \ +// m.def(#name, [](const PyTensor& t, const PyTensor& t2) { \ +// return PT(Tensor::name(T(t), T(t2))); \ +// }) +// +// #define TERNARY_FUNCTION(name) \ +// m.def(#name, [](const PyTensor& t, const PyTensor& t2, const PyTensor& t3) { \ +// return PT(Tensor::name(T(t), T(t2), T(t3))); \ +// }) +// +// void TensorFunctionsDefinition(py::module& m) { +// UNARY_FUNCTION(copy); +// UNARY_FUNCTION(abs); +// UNARY_FUNCTION(ceil); +// UNARY_FUNCTION(floor); +// UNARY_FUNCTION(round); +// UNARY_FUNCTION(trunc); +// UNARY_FUNCTION(sign); +// UNARY_FUNCTION(frac); +// UNARY_FUNCTION(sin); +// UNARY_FUNCTION(cos); +// UNARY_FUNCTION(tan); +// UNARY_FUNCTION(asin); +// UNARY_FUNCTION(acos); +// UNARY_FUNCTION(atan); +// UNARY_FUNCTION(sinh); +// UNARY_FUNCTION(cosh); +// UNARY_FUNCTION(tanh); +// UNARY_FUNCTION(exp); +// UNARY_FUNCTION(exp2); +// UNARY_FUNCTION(log); +// UNARY_FUNCTION(log2); +// UNARY_FUNCTION(sqrt); +// UNARY_FUNCTION(sqr); +// UNARY_FUNCTION(rsqrt); +// UNARY_FUNCTION(rcp); +// +// UNARY_FUNCTION(pcg); +// UNARY_FUNCTION(pcgf); +// UNARY_FUNCTION(reversebits); +// +// m.def("float", [](const PyTensor& t) { return PT(Tensor::tofloat(T(t))); }); +// m.def("uint", [](const PyTensor& t) { return PT(Tensor::touint(T(t))); }); +// m.def("int", [](const PyTensor& t) { return PT(Tensor::toint(T(t))); }); +// m.def("bool", [](const PyTensor& t) { return PT(Tensor::tobool(T(t))); }); +// +// m.def("asfloat", [](const PyTensor& t) { return PT(Tensor::asfloat(T(t))); }); +// m.def("asuint", [](const PyTensor& t) { return PT(Tensor::asuint(T(t))); }); +// m.def("asint", [](const PyTensor& t) { return PT(Tensor::asint(T(t))); }); +// +// BINARY_FUNCTION(min); +// BINARY_FUNCTION(max); +// BINARY_FUNCTION(pow); +// BINARY_FUNCTION(atan2); +// BINARY_FUNCTION(modf); +// +// BINARY_FUNCTION(grad); +// +// TERNARY_FUNCTION(clamp); +// TERNARY_FUNCTION(fma); +// TERNARY_FUNCTION(lerp); +// TERNARY_FUNCTION(select); +// TERNARY_FUNCTION(smoothstep); +// +// m.def("scatterAdd", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterAdd(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterAddPrev", [](const PyTensor& t, const PyTensor& t2) { +// return PT(Tensor::ScatterAddPrev(*t.Value(), T(t2), t.Indices())); +// }); +// +// m.def("scatterMin", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterMin(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterMax", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterMax(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterOr", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterOr(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterAnd", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterAnd(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("scatterXor", [](const PyTensor& t, const PyTensor& t2) { +// Tensor::ScatterXor(*t.Value(), T(t2), t.Indices()); +// }); +// +// m.def("buffer", [](py::list shape, TFDataFormat type) { +// return PT(Tensor::Memory(Reverse(TensorsFromList(shape)), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// m.def("buffer", [](std::vector shape, TFDataFormat type) { +// return PT(Tensor::Memory(Reverse(shape), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// +// m.def("local_buffer", [](int size, TFDataFormat type) { +// return PT(Tensor::LocalMemory(size, type)); +// }, py::arg("size"), py::arg("type") = TFTypeFloat32); +// m.def("group_buffer", [](int size, TFDataFormat type) { +// return PT(Tensor::GroupMemory(size, type)); +// }, py::arg("size"), py::arg("type") = TFTypeFloat32); +// m.def("group_barrier", []() { +// Tensor::GroupBarrier(); +// }); +// +// m.def("zeros", [](py::list shape, TFDataFormat type) { +// return PT(Tensor::Constant(0u, Reverse(TensorsFromList(shape)), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// +// m.def("const", [](float value, py::list shape) { +// return PT(Tensor::Constant(Reverse(TensorsFromList(shape)), value)); +// }); +// m.def("const", [](float value, std::vector shape) { +// return PT(Tensor::Constant(Reverse(shape), value)); +// }, py::arg("value"), py::arg("shape") = std::vector{}); +// +// m.def("const", [](int value, py::list shape) { +// return PT(Tensor::Constant(Reverse(TensorsFromList(shape)), value)); +// }); +// +// m.def("const", [](int value, std::vector shape) { +// return PT(Tensor::Constant(Reverse(shape), value)); +// }, py::arg("value"), py::arg("shape") = std::vector{}); +// +// m.def("input", [](std::vector shape, TFDataFormat type) { +// return PT(Tensor::Input(Reverse(shape), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// +// m.def("input", [](py::list shape, TFDataFormat type) { +// return PT(Tensor::Input(Reverse(TensorsFromList(shape)), type)); +// }, py::arg("shape"), py::arg("type") = TFTypeFloat32); +// +// m.def("index", [](int dim, py::list shape) { +// return PT(Tensor::Index(Reverse(TensorsFromList(shape)), dim)); +// }); +// +// m.def("hash", [](py::list shape, const PyTensor& seed) { +// return PT(Tensor::Hash(Reverse(TensorsFromList(shape)), T(seed))); +// }, py::arg("shape"), py::arg("seed")); +// +// m.def("random_value", [](py::list shape, const PyTensor& seed) { +// return PT(Tensor::Random(Reverse(TensorsFromList(shape)), T(seed))); +// }, py::arg("shape"), py::arg("seed")); +// +// m.def("element_index", [](py::list shape) { +// return PT(Tensor::ElementIndex(Reverse(TensorsFromList(shape)))); +// }, py::arg("shape")); +// +// m.def("flat_index", [](py::list shape, py::list indices) { +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// Tensors index_tensors = Reverse(TensorsFromList(indices)); +// return PT(Tensor::FlatIndex(shape_tensors, index_tensors)); +// }); +// +// m.def("indices_from_flat_index", [](const PyTensor& index, py::list shape) { +// py::tuple indices = py::tuple(shape.size()); +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// Tensors indices_tensors = Reverse(Tensor::IndicesFromFlatIndex(&T(index), shape_tensors)); +// for (int i = 0; i < indices_tensors.size(); i++) { +// indices[i] = PT(*indices_tensors[i]); +// } +// return indices; +// }); +// +// m.def("get_copy", [](const PyTensor& t) { return PT(*Tensor::GetCopy(T(t))); }); +// +// m.def("indices", [](py::list shape) { +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// int dim = (int)shape_tensors.size(); +// py::tuple indices = py::tuple(shape_tensors.size()); +// for (int i = 0; i < shape_tensors.size(); i++) { +// auto t = PT(Tensor::Index(shape_tensors, dim - i - 1)); +// indices[i] = t; +// } +// return indices; +// }); +// +// m.def("indices", [](std::vector shape) { +// py::tuple indices = py::tuple(shape.size()); +// int dim = (int)shape.size(); +// for (int i = 0; i < shape.size(); i++) { +// auto t = PT(Tensor::Index(Reverse(shape), dim - i - 1)); +// indices[i] = t; +// } +// return indices; +// }); +// +// +// m.def("index_grid", [](py::list begin, py::list end) { +// Tensors begin_tensors = Reverse(TensorsFromList(begin)); +// Tensors end_tensors = Reverse(TensorsFromList(end)); +// Tensors index_grid = Reverse(Tensor::IndexGrid(begin_tensors, end_tensors)); +// +// py::tuple indices = py::tuple(begin.size()); +// for (int i = 0; i < index_grid.size(); i++) { +// indices[i] = PT(*index_grid[i]); +// } +// return indices; +// }); +// +// m.def("index_grid", [](py::list begin, py::list end, py::list step) { +// Tensors begin_tensors = Reverse(TensorsFromList(begin)); +// Tensors end_tensors = Reverse(TensorsFromList(end)); +// Tensors step_tensors = Reverse(TensorsFromList(step)); +// Tensors index_grid = Reverse(Tensor::IndexGrid(begin_tensors, end_tensors, step_tensors)); +// +// py::tuple indices = py::tuple(begin.size()); +// for (int i = 0; i < index_grid.size(); i++) { +// indices[i] = PT(*index_grid[i]); +// } +// return indices; +// }); +// +// m.def("reshape", [](const PyTensor& t, py::list shape, TFDataFormat type) { +// return PT(Tensor::Reshape(T(t), Reverse(TensorsFromList(shape)), type)); +// }, py::arg("t"), py::arg("shape"), py::arg("type") = TFTypeNone); +// +// m.def("assert_tensor", [](const PyTensor& t, py::list target_shape, TFDataFormat target_type) { +// return PT(Tensor::Assert(T(t), Reverse(TensorsFromList(target_shape)), target_type)); +// }); +// m.def("split_dim", [](const PyTensor& t, const int split_size, const int axis) { +// return PT(Tensor::SplitDim(T(t), split_size, -axis-1)); +// }, py::arg("t"), py::arg("split_size"), py::arg("axis") = -1); +// m.def("merge_dim", [](const PyTensor& t, const int axis, const PyTensor* target_size) { +// const Tensor* target_size_ptr = target_size ? &T(*target_size) : nullptr; +// return PT(Tensor::MergeDim(T(t), -axis-1, target_size_ptr)); +// }, py::arg("t"), py::arg("axis") = -1, py::arg("target_size") = nullptr); +// m.def("repeat", [](const PyTensor& t, const PyTensor& repeats, const int axis) { +// return PT(Tensor::Repeat(T(t), T(repeats), -axis-1)); +// }, py::arg("t"), py::arg("repeats"), py::arg("axis") = -1); +// +// //algorithm functions +// m.def("sum", [](const PyTensor& t, const int axis) { return PT(Tensor::Sum(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Sum the elements of the tensor along the axis"); +// +// m.def("norm", [](const PyTensor& t, const int axis) { return PT(Tensor::Norm(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the norm of the tensor along the axis"); +// +// m.def("mean", [](const PyTensor& t, const int axis) { return PT(Tensor::Mean(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the mean of the tensor along the axis"); +// +// m.def("min", [](const PyTensor& t, const int axis) { return PT(Tensor::Min(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the min of the tensor along the axis"); +// +// m.def("max", [](const PyTensor& t, const int axis) { return PT(Tensor::Max(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the max of the tensor along the axis"); +// +// m.def("any", [](const PyTensor& t, const int axis) { return PT(Tensor::Any(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Do an OR operation along the axis"); +// +// m.def("all", [](const PyTensor& t, const int axis) { return PT(Tensor::All(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Do an AND operation along the axis"); +// +// m.def("prefix_sum", [](const PyTensor& t, const int axis) { return PT(Tensor::PrefixSum(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Compute the prefix sum of the tensor along the axis"); +// +// m.def("reverse", [](const PyTensor& t, const int axis) { return PT(Tensor::Reverse(T(t), -axis-1)); }, +// py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Reverse the tensor along the axis"); +// +// m.def("transpose", [](const PyTensor& t, int dim1, int dim2) { +// return PT(Tensor::Transpose(T(t), -dim1-1, -dim2-1)); +// }, py::arg("t"), py::kw_only(), py::arg("dim1") = -2, py::arg("dim2") = -1, "Transpose the tensor"); +// +// m.def("unsqueeze", [](const PyTensor& t, int dim) { +// return PT(Tensor::Unsqueeze(T(t), -dim-1)); +// }, py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Unsqueeze the tensor"); +// +// m.def("squeeze", [](const PyTensor& t, int dim) { +// return PT(Tensor::Squeeze(T(t), -dim-1)); +// }, py::arg("t"), py::kw_only(), py::arg("axis") = -1, "Squeeze the tensor"); +// +// m.def("dot", [](const PyTensor& t, const PyTensor& t2, int axis) { +// return PT(Tensor::Dot(T(t), T(t2), -axis-1)); +// }, py::arg("t"), py::arg("t2"), py::kw_only(), py::arg("axis") = -1, "Dot product of two tensors"); +// +// m.def("matmul", [](const PyTensor& t, const PyTensor& t2) { +// return PT(Tensor::Matmul(T(t), T(t2))); +// }, py::arg("t"), py::arg("t2"), "Matrix multiplication of two tensors"); +// +// m.def("region_begin", [](const std::string& name) { +// Tensor::BeginRegion(name); +// }, py::arg("name"), "Begin a debug region"); +// +// m.def("region_end", [](const std::string& name) { +// Tensor::EndRegion(name); +// }, py::arg("name"), "End a debug region"); +// +// m.def("register_custom_operation", [](const std::string& name, vector overloads, py::function impl, py::function vjp) { +// auto cpp_impl = [impl](Tensors& output, map inputs, const Tensor* tensor, vector axes) { +// py::list input_list; +// for (auto& [id, tensor] : inputs) { +// input_list.append(PT(*tensor)); +// } +// py::list output_list = impl(input_list, PT(*tensor), axes).cast(); +// for (int i = 0; i < output_list.size(); i++) { +// PyTensor* t = output_list[i].cast(); +// output.push_back(&t->Get()); +// } +// }; +// +// auto cpp_vjp = [vjp](map inputs, const Tensor* gradient, const Tensor* tensor) { +// py::list input_list; +// for (auto& [id, tensor] : inputs) { +// input_list.append(PT(*tensor)); +// } +// py::list output_list = vjp(input_list, PT(*gradient), PT(*tensor)).cast(); +// Tensors gradients; +// for (int i = 0; i < output_list.size(); i++) { +// PyTensor* t = output_list[i].cast(); +// gradients.push_back(&t->Get()); +// } +// return gradients; +// }; +// +// RegisterAlgorithmicPrimitive(name, overloads, cpp_impl, cpp_vjp); +// }, py::arg("name"), py::arg("overloads"), py::arg("impl"), py::arg("vjp"), "Register a custom operation"); +// +// m.def("custom", [](const std::string& name, py::list inputs, py::list shape) { +// Tensors input_tensors = TensorsFromList(inputs); +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// return PT(Tensor::CustomOperation(name, input_tensors, shape_tensors)); +// }, py::arg("name"), py::arg("inputs"), py::arg("shape"), "Run custom operation"); +// +// m.def("custom", [](const std::string& name, py::list inputs) { +// Tensors input_tensors = TensorsFromList(inputs); +// Tensors shape_tensors = input_tensors[0]->GetShape(); +// return PT(Tensor::CustomOperation(name, input_tensors, shape_tensors)); +// }, py::arg("name"), py::arg("inputs"), "Run custom operation"); +// +// m.def("print_value", [](const std::string& name, const PyTensor& t) { +// Tensor::PrintValue(name, T(t)); +// }, py::arg("name"), py::arg("t"), "Print the value of the tensor"); +// +// m.def("assert_value", [](const std::string& name, const PyTensor& t) { +// Tensor::AssertValue(name, T(t)); +// }, py::arg("name"), py::arg("t"), "Assert the value of the tensor"); +// } +// +// } // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/TensorMemory.cpp b/TensorFrost/src/Definitions/TensorMemory.cpp new file mode 100644 index 00000000..f6186b49 --- /dev/null +++ b/TensorFrost/src/Definitions/TensorMemory.cpp @@ -0,0 +1,72 @@ +// #include +// #include +// +// #include +// #include +// +// namespace TensorFrost { +// +// void TensorMemoryDefinition(py::module& m, +// py::class_& py_tensor_mem) { +// //define constructors from numpy arrays +// py_tensor_mem.def(py::init([](py::array arr) { +// return PyTensorMemory(arr); +// }), "Create a TensorMemory from a numpy array", py::return_value_policy::take_ownership); +// +// // "constructor" +// m.def( +// "tensor", +// [](const std::vector& shape, TFDataFormat type) { +// return PyTensorMemory(shape, type); +// },"Create a TensorMemory with the given shape", py::return_value_policy::take_ownership); +// +// // "constructor" from numpy array +// m.def( +// "tensor", +// [](py::array arr) { +// return new PyTensorMemory(arr); +// }, +// "Create a TensorMemory from a numpy array", py::return_value_policy::take_ownership); +// +// // properties +// py_tensor_mem.def_property_readonly("shape", [](const PyTensorMemory& t) { +// vector shape = t.Shape(); +// return py::cast(shape); +// }); +// +// py_tensor_mem.def_property_readonly("type", [](const PyTensorMemory& t) { +// return t.GetFormat(); +// }); +// +// py_tensor_mem.def_property_readonly("size", [](const PyTensorMemory& t) { +// return GetSize(t.tensor_); +// }); +// +// // to numpy array +// py_tensor_mem.def_property_readonly( +// "numpy", +// [](const PyTensorMemory& t) +// -> std::variant, py::array_t, +// py::array_t, py::array_t> { +// if (t.GetFormat() == TFTypeFloat32) { +// return t.ToPyArray(); +// } else if (t.GetFormat() == TFTypeInt32) { +// return t.ToPyArray(); +// } else if (t.GetFormat() == TFTypeUint32) { +// return t.ToPyArray(); +// } else if (t.GetFormat() == TFTypeBool32) { +// return t.ToPyArray(); +// } else { +// throw std::runtime_error("Unsupported data type for numpy conversion"); +// } +// }, +// "Readback data from tensor memory to a numpy array", py::return_value_policy::take_ownership); +// +// m.def("allocated_memory", []() { return global_memory_manager->GetAllocatedSize(); }, +// "Get the amount of memory currently used by the memory manager"); +// +// m.def("unused_allocated_memory", []() { return global_memory_manager->GetUnusedAllocatedSize(); }, +// "Get the amount of memory currently allocated but not used by the memory manager"); +// } +// +// } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/src/Definitions/TensorProgram.cpp b/TensorFrost/src/Definitions/TensorProgram.cpp new file mode 100644 index 00000000..dd1cfcf5 --- /dev/null +++ b/TensorFrost/src/Definitions/TensorProgram.cpp @@ -0,0 +1,199 @@ +// #include +// #include +// +// #include +// #include +// #include +// +// namespace TensorFrost { +// +// void TensorProgramDefinition(py::module& m, +// py::class_& tensor_program) { +// m.def( +// "compile", +// [](const py::function& py_evaluate) { +// // Extract the name of the Python function +// std::string func_name = +// py_evaluate.attr("__name__").cast(); +// +// vector inputs = GetFunctionArguments(py_evaluate); +// vector arg_names; +// vector arg_props; +// for (auto arg : inputs) { +// arg_names.push_back(std::get<0>(arg)); +// py::object arg_prop = std::get<1>(arg); +// py::object arg_default = std::get<2>(arg); +// if (py::isinstance(arg_prop)) { +// PyTensorArg arg_tensor = arg_prop.cast(); +// arg_props.push_back(arg_tensor); +// } else { +// throw std::runtime_error("Unsupported input type " + std::string(py::str(arg_prop))); +// } +// } +// +// TensorProgram& program = *new TensorProgram( +// [py_evaluate, arg_names, arg_props]() -> Tensors { +// py::gil_scoped_acquire acquire; +// std::vector args; +// //create inputs from the arguments +// for (size_t i = 0; i < arg_names.size(); i++) { +// Tensor& input = Tensor::Input(arg_props[i].shape, arg_props[i].type); +// input.SetDebugName(arg_names[i]); +// PyTensor* py_tensor = new PyTensor(&input); +// args.push_back(py_tensor); +// } +// //convert to py::args +// py::args py_args = py::cast(args); +// py::object result = py_evaluate(*py_args); +// Tensors outputs; +// //if the result is a single tensor +// if (py::isinstance(result)) { +// outputs.push_back(&py::cast(result).Get()); +// } else { +// auto py_outputs = py::cast>(result); +// for (PyTensor output : py_outputs) { +// outputs.push_back(&output.Get()); +// } +// } +// return outputs; +// }, +// func_name); +// +// py::print(program.PrintProperties()); +// return &program; +// }, +// "Compile a TensorProgram from a python function"); +// +// tensor_program.def( +// "__call__", +// [](TensorProgram& program, py::args py_inputs) -> std::variant { +// vector inputs_props; +// vector temp_numpy_tensors; +// for (auto arg : py_inputs) { +// if (py::isinstance(arg)) { //if just tensor memory +// PyTensorMemory* mem = &arg.cast(); +// inputs_props.push_back(arg.cast()); +// } else if (py::isinstance(arg)) { //if module then add its parameters +// Module* module = &arg.cast(); +// py::list params = module->parameters(); +// for (auto param : params) { +// inputs_props.push_back(param.cast()); +// } +// } else if (py::isinstance(arg)) { //if numpy array then create pytensormemory from it and add it +// py::array arr = arg.cast(); +// PyTensorMemory* temp_tensor = new PyTensorMemory(arr); +// inputs_props.push_back(py::cast(temp_tensor, py::return_value_policy::take_ownership)); +// temp_numpy_tensors.push_back(temp_tensor->tensor_); +// } else if (py::isinstance(arg)) { //if list then convert to py::array then create pytensormemory from it and add it +// py::array arr = ListToArray(arg.cast()); +// PyTensorMemory* temp_tensor = new PyTensorMemory(arr); +// inputs_props.push_back(py::cast(temp_tensor, py::return_value_policy::take_ownership)); +// temp_numpy_tensors.push_back(temp_tensor->tensor_); +// } else { +// throw std::runtime_error("Unsupported input type " + std::string(py::str(arg))); +// } +// } +// +// vector inputs; +// for (auto input : inputs_props) { +// PyTensorMemory* mem = input.cast(); +// inputs.push_back(mem->tensor_); +// } +// vector outputs = program.Evaluate(inputs); +// +// //remove temporary tensors if they are not in the outputs +// for (TFTensor* temp_tensor : temp_numpy_tensors) { +// bool found = false; +// for (TFTensor* output : outputs) { +// if (temp_tensor->buffer == output->buffer) { +// found = true; +// break; +// } +// } +// if (!found) { +// global_memory_manager->DeallocateTensor(*temp_tensor); +// } +// } +// +// vector output_tensors; +// for (size_t i = 0; i < outputs.size(); i++) { +// //if any of the outputs are also inputs, then replace them with the input tensors +// TFTensor* out = outputs[i]; +// bool is_input = false; +// for (size_t j = 0; j < inputs_props.size(); j++) { +// PyTensorMemory* in = inputs_props[j].cast(); +// if (out->buffer == in->tensor_->buffer) { +// output_tensors.push_back(inputs_props[j]); +// is_input = true; +// break; +// } +// } +// if (is_input) { +// continue; +// } +// //otherwise create a new tensor memory +// output_tensors.push_back(py::cast(new PyTensorMemory(outputs[i]), py::return_value_policy::take_ownership)); +// } +// +// //if there is only one output, return the tensor memory +// if (outputs.size() == 1) { +// return output_tensors[0]; +// } else { +// //convert to py::tuple of PyTensorMemory* +// py::tuple py_outputs = py::tuple(outputs.size()); +// for (size_t i = 0; i < outputs.size(); i++) { +// py_outputs[i] = output_tensors[i]; +// } +// return py_outputs; +// } +// }, +// "Evaluate the TensorProgram with the given inputs"); +// +// tensor_program.def( +// "list_operations", +// [](TensorProgram& program, bool compact) { +// std::string listing = "List of operations:\n"; +// listing += GetOperationListing(program.ir, compact); +// return py::str(listing); +// }, +// py::arg("compact") = true); +// +// tensor_program.def("compiled_code", [](TensorProgram& program) { +// string code = program.program->generated_code_; +// return py::str(code); +// }); +// +// tensor_program.def("get_kernels", [](TensorProgram& program) { +// vector kernel_source; +// for (auto& kernel : program.program->kernels_) { +// kernel_source.push_back(kernel.full_generated_code_); +// } +// return kernel_source; +// }); +// +// tensor_program.def("get_main_function", [](TensorProgram& program) { +// return program.program->main_function_; +// }); +// +// tensor_program.def("get_last_execution_time", [](TensorProgram& program) { +// return program.program->last_execution_time; +// }); +// +// m.def("get_all_generated_main_functions", []() { +// return global_kernel_manager->GetAllMainFunctions(); +// }); +// +// m.def("get_all_generated_kernels", []() { +// return global_kernel_manager->GetAllKernels(); +// }); +// +// m.def("get_cpp_header", []() { +// return GetCPPHeader(); +// }); +// +// m.def("get_cpp_implementation", []() { +// return GetCPPImplementation(); +// }); +// } +// +// } // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/TensorScope.cpp b/TensorFrost/src/Definitions/TensorScope.cpp new file mode 100644 index 00000000..07bac335 --- /dev/null +++ b/TensorFrost/src/Definitions/TensorScope.cpp @@ -0,0 +1,118 @@ +// #include +// #include +// +// #include +// +// namespace TensorFrost { +// +// void ScopeDefinitions(py::module& m, py::class_& py_tensor) { +// m.def( +// "loop", +// [](const py::function& body, const PyTensor& begin, const PyTensor& end, +// const PyTensor& step) { +// // wrap the function to convert the PyTensor to Tensor +// std::function f2 = [&body](const Tensor& t) { +// py::gil_scoped_acquire acquire; +// body(PT(t)); +// }; +// +// Tensor::Loop(T(begin), T(end), T(step), f2); +// }, +// py::arg("begin") = 0, py::arg("end"), py::arg("step") = 1, +// py::arg("body")); +// +// m.def( +// "if_cond", +// [](const PyTensor& condition, const py::function& true_body) { +// std::function f = [&true_body]() { +// py::gil_scoped_acquire acquire; +// true_body(); +// }; +// Tensor::If(T(condition), f); +// }, +// py::arg("condition"), py::arg("true_body")); +// +// m.def( +// "if_cond", +// [](const PyTensor& condition, const py::function& true_body, +// const py::function& false_body) { +// std::function f1 = [&true_body]() { +// py::gil_scoped_acquire acquire; +// true_body(); +// }; +// std::function f2 = [&false_body]() { +// py::gil_scoped_acquire acquire; +// false_body(); +// }; +// Tensor::If(T(condition), f1, f2); +// }, +// py::arg("condition"), py::arg("true_body"), py::arg("false_body")); +// +// m.def("break_loop", []() { Tensor::Break(); }); +// m.def("continue_loop", []() { Tensor::Continue(); }); +// +// m.def( +// "kernel", +// [](py::list shape, const py::function& body) { +// // wrap the function to convert the PyTensor to Tensor +// std::function f2 = +// [&body](const Tensors& tensors) { +// py::gil_scoped_acquire acquire; +// PyTensors py_tensors = PyTensorsFromTensors(tensors); +// body(py_tensors); +// }; +// +// Tensor::Kernel(Reverse(TensorsFromList(shape)), f2); +// }, +// py::arg("shape"), py::arg("body")); +// +// // m.def( +// // "vmap", +// // [](py::list inputs, py::list shape, const py::function& func) { +// // std::function f = [&func]() { +// // py::gil_scoped_acquire acquire; +// // func(); +// // }; +// // Tensor::If(T(condition), f); +// // }, +// // py::arg("condition"), py::arg("true_body")); +// +// py_tensor.def("__enter__", &PyTensor::__enter__); +// py_tensor.def("__exit__", &PyTensor::__exit__); +// +// //loop scope +// m.def("loop", +// [](const PyTensor& begin, const PyTensor& end, const PyTensor& step) { +// Tensor& for_loop = Tensor::Loop(T(begin), T(end), T(step)); +// return PT(for_loop); +// }); +// +// m.def("loop", +// [](const PyTensor& begin, const PyTensor& end) { +// Tensor& for_loop = Tensor::Loop(T(begin), T(end), T(PyTensor(1))); +// return PT(for_loop); +// }); +// +// m.def("loop", +// [](const PyTensor& end) { +// Tensor& for_loop = Tensor::Loop(T(PyTensor(0)), T(end), T(PyTensor(1))); +// return PT(for_loop); +// }); +// +// //if scope +// m.def("if_cond", +// [](const PyTensor& condition) { +// Tensor& if_cond = Tensor::If(T(condition)); +// return PT(if_cond); +// }); +// +// //kernel scope +// m.def("kernel", +// [](py::list shape, vector group_size) { +// Tensors shape_tensors = Reverse(TensorsFromList(shape)); +// Tensor& kernel = Tensor::Kernel(shape_tensors, group_size); +// return PT(kernel); +// }, py::arg("shape"), py::arg("group_size") = vector()); +// } +// +// } // namespace TensorFrost \ No newline at end of file diff --git a/TensorFrost/src/Definitions/VulkanBindings.cpp b/TensorFrost/src/Definitions/VulkanBindings.cpp new file mode 100644 index 00000000..cb30e428 --- /dev/null +++ b/TensorFrost/src/Definitions/VulkanBindings.cpp @@ -0,0 +1,244 @@ +#include "Definitions/VulkanBindings.h" +#include "VulkanInterface.h" + +#include +#include + +#include +#include +#include + +namespace py = pybind11; + +namespace TensorFrost { + +void VulkanDefinitions(py::module_& m) { + py::class_(m, "Buffer", "Vulkan-backed storage buffer exposed to Python.") + .def(py::init(), + py::arg("count"), py::arg("dtype_size"), py::arg("read_only") = false, + "Create a buffer sized for `count` elements of size `dtype_size`.") + .def_property_readonly("size", &PyBuffer::byteSize, "Total size of the buffer in bytes.") + .def_property_readonly("count", &PyBuffer::elementCapacity, + "Maximum number of elements the buffer can hold for the configured dtype size.") + .def_property_readonly("read_only", &PyBuffer::isReadOnly, + "Whether the buffer is flagged as read-only for compute kernels.") + .def("setData", &PyBuffer::setData, py::arg("data"), py::arg("offset") = 0, + "Upload data from a NumPy array or bytes-like object into the buffer.") + .def("getData", + [](const PyBuffer& self, const py::object& dtype, const py::object& count, size_t offset) { + return self.getData(dtype, count, offset); + }, + py::arg("dtype") = py::none(), py::arg("count") = py::none(), py::arg("offset") = 0, + "Download data from the buffer into a newly allocated NumPy array.") + .def("release", &PyBuffer::release, + "Explicitly destroy the underlying Vulkan buffer and release its memory."); + + m.def("createBuffer", + [](size_t count, size_t dtypeSize, bool readOnly) { + return PyBuffer(count, dtypeSize, readOnly); + }, + py::arg("count"), py::arg("dtype_size"), py::arg("read_only") = false, + py::return_value_policy::move, + "Convenience helper to construct a :class:`Buffer` without calling the class directly."); + + py::class_(m, "ComputeProgram", + "Compiled compute pipeline that can be dispatched on the GPU.") + .def_property_readonly("readonly_count", &PyComputeProgram::readonlyCount, + "Number of read-only storage buffers expected by the program.") + .def_property_readonly("readwrite_count", &PyComputeProgram::readwriteCount, + "Number of read-write storage buffers expected by the program.") + .def_property_readonly("push_constant_size", &PyComputeProgram::pushConstantSize, + "Size in bytes of the push-constant block expected by this program (0 if unused).") + .def("run", &PyComputeProgram::run, + py::arg("readonly_buffers"), + py::arg("readwrite_buffers"), + py::arg("group_count"), + py::arg("push_constants") = py::none(), + "Dispatch the compute pipeline with the provided buffers, workgroup count, and optional push constants.") + .def("release", &PyComputeProgram::release, + "Explicitly destroy the underlying Vulkan pipeline and associated resources."); + + m.def("createComputeProgramFromSlang", + [](const std::string& moduleName, + const std::string& source, + const std::string& entry, + uint32_t roCount, + uint32_t rwCount, + uint32_t pushConstantSize) { + return MakeComputeProgramFromSlang( + moduleName, source, entry, roCount, rwCount, pushConstantSize); + }, + py::arg("module_name"), py::arg("source"), py::arg("entry"), + py::arg("ro_count"), py::arg("rw_count"), + py::arg("push_constant_size") = 0, + py::return_value_policy::move, + "Compile a Slang module to SPIR-V and wrap it in a :class:`ComputeProgram`."); + + py::class_(m, "Window", + "GLFW-backed Vulkan swapchain window for presenting compute output.") + .def(py::init(), py::arg("width"), py::arg("height"), py::arg("title"), + "Create a window with an attached Vulkan swapchain.") + .def_property_readonly("size", &PyWindow::size, + "Current window extent as a tuple ``(width, height)``.") + .def_property_readonly("format", &PyWindow::format, + "Pixel format of the swapchain image as a Vulkan enum value.") + .def("isOpen", &PyWindow::isOpen, + "Return ``True`` while the window is alive and the user has not closed it.") + .def("drawBuffer", &PyWindow::drawBuffer, + py::arg("buffer"), py::arg("width"), py::arg("height"), py::arg("offset") = 0, + "Copy a buffer of packed pixels onto the swapchain.") + .def("present", &PyWindow::present, + "Present the current frame without uploading new pixels.") + .def("close", &PyWindow::close, + "Destroy the window and release its swapchain resources.") + .def("imgui_begin", &PyWindow::imguiBegin, + py::arg("name"), py::arg("open") = py::none(), py::arg("flags") = 0, + "Begin a new ImGui window, returning (visible, open_flag_or_None).") + .def("imgui_end", &PyWindow::imguiEnd, + "End the current ImGui window.") + .def("imgui_text", &PyWindow::imguiText, + py::arg("text"), + "Add text to the current ImGui window.") + .def("imgui_button", &PyWindow::imguiButton, + py::arg("label"), + "Add a button and return True when pressed.") + .def("imgui_checkbox", &PyWindow::imguiCheckbox, + py::arg("label"), py::arg("value"), + "Add a checkbox and return the updated value.") + .def("imgui_slider_int", &PyWindow::imguiSliderInt, + py::arg("label"), py::arg("value"), py::arg("min"), py::arg("max"), + "Slider that returns the updated integer value.") + .def("imgui_slider_float", &PyWindow::imguiSliderFloat, + py::arg("label"), py::arg("value"), py::arg("min"), py::arg("max"), + "Slider that returns the updated float value.") + .def("imgui_plot_lines", &PyWindow::imguiPlotLines, + py::arg("label"), py::arg("values"), py::arg("values_offset") = 0, + py::arg("overlay_text") = "", py::arg("scale_min") = FLT_MAX, + py::arg("scale_max") = FLT_MAX, + py::arg("graph_size") = py::make_tuple(0.0f, 0.0f), + py::arg("stride") = sizeof(float), + "Plot a sequence of values as lines.") + .def("imgui_scale_all_sizes", &PyWindow::imguiScaleAllSizes, + py::arg("scale"), + "Scale all ImGui sizes by a factor.") + .def("imgui_add_background_text", &PyWindow::imguiAddBackgroundText, + py::arg("text"), py::arg("pos"), py::arg("color"), + "Draw text in the background draw list.") + .def("imgui_same_line", &PyWindow::imguiSameLine, + py::arg("offset_from_start_x") = 0.0f, py::arg("spacing") = -1.0f, + "Place the next item on the same horizontal line.") + .def("imgui_separator", &PyWindow::imguiSeparator, + "Insert a separator line between items.") + .def("imgui_spacing", &PyWindow::imguiSpacing, + "Insert vertical spacing between items.") + .def("imgui_indent", &PyWindow::imguiIndent, + py::arg("indent_w") = 0.0f, + "Increase the current horizontal indent.") + .def("imgui_unindent", &PyWindow::imguiUnindent, + py::arg("indent_w") = 0.0f, + "Decrease the current horizontal indent.") + .def("imgui_begin_child", &PyWindow::imguiBeginChild, + py::arg("id"), py::arg("size") = py::none(), py::arg("border") = false, py::arg("flags") = 0, + "Begin a child region and return True if visible. Always pair with :meth:`imgui_end_child` even when the return value is False.") + .def("imgui_end_child", &PyWindow::imguiEndChild, + "End the current child region.") + .def("imgui_text_wrapped", &PyWindow::imguiTextWrapped, + py::arg("text"), + "Render wrapped text within the current column width.") + .def("imgui_text_colored", &PyWindow::imguiTextColored, + py::arg("color"), py::arg("text"), + "Render text with the given RGBA color.") + .def("imgui_bullet_text", &PyWindow::imguiBulletText, + py::arg("text"), + "Render text preceded by a bullet.") + .def("imgui_input_text", &PyWindow::imguiInputText, + py::arg("label"), py::arg("value"), py::arg("buffer_length") = 0, py::arg("flags") = 0, + "Input text returning (modified, value).") + .def("imgui_input_int", &PyWindow::imguiInputInt, + py::arg("label"), py::arg("value"), py::arg("step") = 1, py::arg("step_fast") = 100, py::arg("flags") = 0, + "Integer input returning the updated value.") + .def("imgui_input_float", &PyWindow::imguiInputFloat, + py::arg("label"), py::arg("value"), py::arg("step") = 0.0f, py::arg("step_fast") = 0.0f, + py::arg("format") = "%.3f", py::arg("flags") = 0, + "Float input returning the updated value.") + .def("imgui_color_edit3", &PyWindow::imguiColorEdit3, + py::arg("label"), py::arg("color"), py::arg("flags") = 0, + "Color editor returning (modified, rgb tuple).") + .def("imgui_color_edit4", &PyWindow::imguiColorEdit4, + py::arg("label"), py::arg("color"), py::arg("flags") = 0, + "Color editor returning (modified, rgba tuple).") + .def("imgui_begin_main_menu_bar", &PyWindow::imguiBeginMainMenuBar, + "Begin the global main menu bar, returning True if it is visible.") + .def("imgui_end_main_menu_bar", &PyWindow::imguiEndMainMenuBar, + "End the global main menu bar.") + .def("imgui_begin_menu_bar", &PyWindow::imguiBeginMenuBar, + "Begin a menu bar on the current window, returning True if visible.") + .def("imgui_end_menu_bar", &PyWindow::imguiEndMenuBar, + "End the current menu bar.") + .def("imgui_begin_menu", &PyWindow::imguiBeginMenu, + py::arg("label"), py::arg("enabled") = true, + "Begin a menu entry and return True if it is open.") + .def("imgui_end_menu", &PyWindow::imguiEndMenu, + "End the current menu entry.") + .def("imgui_menu_item", &PyWindow::imguiMenuItem, + py::arg("label"), py::arg("shortcut") = py::none(), py::arg("selected") = false, py::arg("enabled") = true, + "Create a menu item and return True when activated.") + .def("imgui_open_popup", &PyWindow::imguiOpenPopup, + py::arg("id"), py::arg("popup_flags") = 0, + "Open a popup window by identifier.") + .def("imgui_begin_popup", &PyWindow::imguiBeginPopup, + py::arg("id"), py::arg("flags") = 0, + "Begin a popup window, returning True if it is open.") + .def("imgui_begin_popup_modal", &PyWindow::imguiBeginPopupModal, + py::arg("name"), py::arg("open") = py::none(), py::arg("flags") = 0, + "Begin a modal popup, returning (visible, open_flag_or_None).") + .def("imgui_end_popup", &PyWindow::imguiEndPopup, + "End the current popup window.") + .def("imgui_close_current_popup", &PyWindow::imguiCloseCurrentPopup, + "Close the current popup window.") + .def("imgui_push_style_color", &PyWindow::imguiPushStyleColor, + py::arg("index"), py::arg("color"), + "Push a style color onto the stack.") + .def("imgui_pop_style_color", &PyWindow::imguiPopStyleColor, + py::arg("count") = 1, + "Pop style colors from the stack.") + .def("imgui_push_style_var_float", &PyWindow::imguiPushStyleVarFloat, + py::arg("index"), py::arg("value"), + "Push a float style variable onto the stack.") + .def("imgui_push_style_var_vec2", &PyWindow::imguiPushStyleVarVec2, + py::arg("index"), py::arg("value"), + "Push a 2D vector style variable onto the stack.") + .def("imgui_pop_style_var", &PyWindow::imguiPopStyleVar, + py::arg("count") = 1, + "Pop style variables from the stack.") + .def("imgui_get_font_global_scale", &PyWindow::imguiGetFontGlobalScale, + "Get the global font scale factor.") + .def("imgui_set_font_global_scale", &PyWindow::imguiSetFontGlobalScale, + py::arg("scale"), + "Set the global font scale factor.") + .def("imgui_get_style_color_vec4", &PyWindow::imguiGetStyleColorVec4, + py::arg("index"), + "Get a style color as an RGBA tuple.") + .def("imgui_set_style_color_vec4", &PyWindow::imguiSetStyleColorVec4, + py::arg("index"), py::arg("color"), + "Set a style color from an RGBA tuple.") + .def("mouse_position", &PyWindow::mousePosition, + "Get the current mouse position in window coordinates.") + .def("is_mouse_button_pressed", &PyWindow::isMouseButtonPressed, + py::arg("button"), + "Return True if the specified mouse button is pressed.") + .def("imgui_want_capture_mouse", &PyWindow::imguiWantCaptureMouse, + "Return True if ImGui wants to capture mouse input this frame.") + .def("consume_scroll_delta", &PyWindow::consumeScrollDelta, + "Consume the accumulated scroll delta as a tuple ``(x, y)`` and reset it to zero."); + + m.def("createWindow", + [](int width, int height, const std::string& title) { + return PyWindow(width, height, title); + }, + py::arg("width"), py::arg("height"), py::arg("title"), + py::return_value_policy::move, + "Convenience helper to construct a :class:`Window` without calling the class directly."); +} + +} // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/VulkanInterface.cpp b/TensorFrost/src/Definitions/VulkanInterface.cpp new file mode 100644 index 00000000..9cab1364 --- /dev/null +++ b/TensorFrost/src/Definitions/VulkanInterface.cpp @@ -0,0 +1,752 @@ +#include "VulkanInterface.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include "Backend/Vulkan.h" +#include "Backend/Window.h" + +namespace py = pybind11; + +namespace TensorFrost { +namespace { + +bool isCContiguous(const py::buffer_info& info) { + py::ssize_t stride = info.itemsize; + for (py::ssize_t d = info.ndim - 1; d >= 0; --d) { + if (info.strides[d] != stride) return false; + stride *= info.shape[d]; + } + return true; +} + +struct PushConstantPayload { + const void* data = nullptr; + size_t size = 0; + std::vector storage; +}; + +PushConstantPayload preparePushConstantPayload(const ComputeProgram& program, + const py::object& pushConstants) { + PushConstantPayload payload; + if (program.pushConstantSize == 0) { + if (!pushConstants.is_none()) { + throw std::runtime_error("program does not declare push constants"); + } + return payload; + } + if (pushConstants.is_none()) { + throw std::runtime_error("push constant payload required for this program"); + } + + py::buffer buffer(pushConstants); + auto info = buffer.request(); + if (!isCContiguous(info)) { + throw std::runtime_error("push constant payload must be C-contiguous"); + } + size_t totalBytes = static_cast(info.size) * static_cast(info.itemsize); + if (totalBytes != program.pushConstantSize) { + throw std::runtime_error( + "push constant payload size mismatch (got " + std::to_string(totalBytes) + + ", expected " + std::to_string(program.pushConstantSize) + ")"); + } + + payload.storage.resize(totalBytes); + std::memcpy(payload.storage.data(), info.ptr, totalBytes); + payload.data = payload.storage.data(); + payload.size = totalBytes; + return payload; +} + +} // namespace + +PyBuffer::PyBuffer(size_t count, size_t dtypeSize, bool readOnly) + : ctx_(&getVulkanContext()), + buffer_(createBuffer(count, dtypeSize, readOnly)), + readOnly_(readOnly), + dtypeSizeHint_(dtypeSize ? dtypeSize : 1), + lastCount_(count), + lastDtype_(py::none()) {} + +PyBuffer::~PyBuffer() { release(); } + +PyBuffer::PyBuffer(PyBuffer&& other) noexcept { moveFrom(std::move(other)); } + +PyBuffer& PyBuffer::operator=(PyBuffer&& other) noexcept { + if (this != &other) { + release(); + moveFrom(std::move(other)); + } + return *this; +} + +bool PyBuffer::valid() const { return ctx_ && buffer_.buffer; } + +size_t PyBuffer::byteSize() const { return buffer_.size; } + +size_t PyBuffer::elementCapacity() const { return dtypeSizeHint_ ? buffer_.size / dtypeSizeHint_ : buffer_.size; } + +bool PyBuffer::isReadOnly() const { return readOnly_; } + +void PyBuffer::release() { + if (ctx_ && buffer_.buffer) { + destroyBuffer(buffer_); + } + buffer_ = {}; + ctx_ = nullptr; + lastDtype_ = py::none(); + lastCount_ = 0; + dtypeSizeHint_ = 0; +} + +void PyBuffer::setData(const py::array& array, size_t offset) { + ensureValid(); + auto info = array.request(); + if (!isCContiguous(info)) throw std::runtime_error("array must be C-contiguous"); + size_t nbytes = static_cast(info.size) * static_cast(info.itemsize); + if (offset + nbytes > buffer_.size) throw std::out_of_range("write out of range"); + { + py::gil_scoped_release release; + setBufferData(buffer_, info.ptr, nbytes, offset); + } + lastDtype_ = array.dtype(); + lastCount_ = static_cast(info.size); + dtypeSizeHint_ = static_cast(info.itemsize ? info.itemsize : 1); +} + +py::array PyBuffer::getData(const py::object& dtypeArg, const py::object& countArg, size_t offset) const { + ensureValid(); + if (offset > buffer_.size) throw std::out_of_range("offset out of range"); + + py::dtype dtype = resolveDtype(dtypeArg); + size_t itemsize = dtype.attr("itemsize").cast(); + if (itemsize == 0) throw std::runtime_error("dtype itemsize cannot be zero"); + + size_t available = buffer_.size - offset; + size_t count = resolveCount(countArg, itemsize, available); + size_t nbytes = count * itemsize; + if (offset + nbytes > buffer_.size) throw std::out_of_range("read out of range"); + + py::array out(dtype, py::array::ShapeContainer{ static_cast(count) }); + auto info = out.request(); + { + py::gil_scoped_release release; + getBufferData(buffer_, info.ptr, nbytes, offset); + } + return out; +} + +Buffer& PyBuffer::raw() { + ensureValid(); + return buffer_; +} + +const Buffer& PyBuffer::raw() const { + ensureValid(); + return buffer_; +} + +void PyBuffer::ensureValid() const { + if (!valid()) throw std::runtime_error("Buffer has been released"); +} + +void PyBuffer::moveFrom(PyBuffer&& other) { + ctx_ = other.ctx_; + buffer_ = other.buffer_; + readOnly_ = other.readOnly_; + dtypeSizeHint_ = other.dtypeSizeHint_; + lastCount_ = other.lastCount_; + lastDtype_ = std::move(other.lastDtype_); + other.ctx_ = nullptr; + other.buffer_ = {}; + other.lastDtype_ = py::none(); +} + +py::dtype PyBuffer::resolveDtype(const py::object& dtypeArg) const { + if (!dtypeArg.is_none()) { + return py::reinterpret_borrow(dtypeArg); + } + if (!lastDtype_.is_none()) { + return py::reinterpret_borrow(lastDtype_); + } + switch (dtypeSizeHint_) { + case 2: return py::dtype::of(); + case 4: return py::dtype::of(); + case 8: return py::dtype::of(); + default: return py::dtype::of(); + } +} + +size_t PyBuffer::resolveCount(const py::object& countArg, size_t itemsize, size_t available) const { + if (!countArg.is_none()) { + return countArg.cast(); + } + if (!lastDtype_.is_none() && itemsize == dtypeSizeHint_ && lastCount_ != 0) { + return std::min(lastCount_, available / itemsize); + } + return available / itemsize; +} + +PyComputeProgram::PyComputeProgram(ComputeProgram&& prog) + : ctx_(&getVulkanContext()), program_(std::move(prog)) {} + +PyComputeProgram::~PyComputeProgram() { release(); } + +PyComputeProgram::PyComputeProgram(PyComputeProgram&& other) noexcept { moveFrom(std::move(other)); } + +PyComputeProgram& PyComputeProgram::operator=(PyComputeProgram&& other) noexcept { + if (this != &other) { + release(); + moveFrom(std::move(other)); + } + return *this; +} + +void PyComputeProgram::run(const py::iterable& readonlyBuffers, + const py::iterable& readwriteBuffers, + uint32_t groupCount, + const py::object& pushConstants) { + ensureValid(); + std::vector ro; + std::vector rw; + collectBuffers(readonlyBuffers, ro, "readonly"); + collectBuffers(readwriteBuffers, rw, "readwrite"); + if (ro.size() != program_.numRO || rw.size() != program_.numRW) { + throw std::runtime_error("buffer count does not match program layout"); + } + auto payload = preparePushConstantPayload(program_, pushConstants); + py::gil_scoped_release release; + runProgram(program_, ro, rw, groupCount, payload.data, payload.size); +} + +void PyComputeProgram::release() { + if (ctx_ && program_.pipeline) { + destroyComputeProgram(program_); + } + program_ = {}; + ctx_ = nullptr; +} + +uint32_t PyComputeProgram::readonlyCount() const { return program_.numRO; } + +uint32_t PyComputeProgram::readwriteCount() const { return program_.numRW; } + +uint32_t PyComputeProgram::pushConstantSize() const { return program_.pushConstantSize; } + +void PyComputeProgram::ensureValid() const { + if (!ctx_ || !program_.pipeline) { + throw std::runtime_error("ComputeProgram has been released"); + } +} + +void PyComputeProgram::collectBuffers(const py::iterable& items, + std::vector& out, + const char* label) { + out.clear(); + for (auto obj : items) { + try { + py::handle handle(obj); + auto* buf = handle.cast(); + if (!buf) { + throw py::cast_error("null buffer pointer"); + } + out.push_back(&buf->raw()); + } catch (const py::cast_error&) { + throw std::runtime_error(std::string("expected Buffer in ") + label + " list"); + } + } +} + +void PyComputeProgram::moveFrom(PyComputeProgram&& other) { + ctx_ = other.ctx_; + program_ = other.program_; + other.ctx_ = nullptr; + other.program_ = {}; +} + +PyWindow::PyWindow(int width, int height, const std::string& title) + : ctx_(&getVulkanContext()), window_(createWindow(width, height, title.c_str())) { + AttachWindowCallbacks(window_); +} + +PyWindow::~PyWindow() = default; + +PyWindow::PyWindow(PyWindow&& other) noexcept { moveFrom(std::move(other)); } + +PyWindow& PyWindow::operator=(PyWindow&& other) noexcept { + if (this != &other) { + moveFrom(std::move(other)); + } + return *this; +} + +bool PyWindow::isOpen() const { + ensureValid(); + return windowOpen(window_); +} + +void PyWindow::drawBuffer(const PyBuffer& buffer, uint32_t width, uint32_t height, size_t offset) { + ensureValid(); + py::gil_scoped_release release; + ::drawBuffer(window_, buffer.raw(), width, height, offset); +} + +void PyWindow::present() { + ensureValid(); + py::gil_scoped_release release; + ::drawBuffer(window_, vk::Buffer{}, window_.extent.width, window_.extent.height, 0); +} + +py::tuple PyWindow::size() { + ensureValid(); + int fbw = 0; + int fbh = 0; + glfwGetFramebufferSize(window_.wnd, &fbw, &fbh); + if (fbw > 0 && fbh > 0) { + window_.extent.width = static_cast(fbw); + window_.extent.height = static_cast(fbh); + } + return py::make_tuple(window_.extent.width, window_.extent.height); +} + +int PyWindow::format() const { + ensureValid(); + return static_cast(window_.format); +} + +void PyWindow::close() { + window_ = {}; + ctx_ = nullptr; +} + +py::tuple PyWindow::imguiBegin(const std::string& name, + const std::optional& open, + int flags) { + bindImGui(); + bool openValue = open.value_or(true); + bool visible = ImGui::Begin(name.c_str(), open ? &openValue : nullptr, flags); + return py::make_tuple(visible, open ? py::cast(openValue) : py::none()); +} + +void PyWindow::imguiEnd() { + bindImGui(); + ImGui::End(); +} + +void PyWindow::imguiText(const std::string& text) { + bindImGui(); + ImGui::TextUnformatted(text.c_str()); +} + +bool PyWindow::imguiButton(const std::string& label) { + bindImGui(); + return ImGui::Button(label.c_str()); +} + +bool PyWindow::imguiCheckbox(const std::string& label, bool value) { + bindImGui(); + bool v = value; + ImGui::Checkbox(label.c_str(), &v); + return v; +} + +int PyWindow::imguiSliderInt(const std::string& label, int value, int min, int max) { + bindImGui(); + int v = value; + ImGui::SliderInt(label.c_str(), &v, min, max); + return v; +} + +float PyWindow::imguiSliderFloat(const std::string& label, float value, float min, float max) { + bindImGui(); + float v = value; + ImGui::SliderFloat(label.c_str(), &v, min, max); + return v; +} + +void PyWindow::imguiPlotLines(const std::string& label, + py::array_t values, + int valuesOffset, + const std::string& overlayText, + float scaleMin, + float scaleMax, + py::tuple graphSize, + int stride) { + bindImGui(); + validateTupleSize(graphSize, 2, "graph_size"); + ImGui::PlotLines( + label.c_str(), + values.data(), + static_cast(values.size()), + valuesOffset, + overlayText.empty() ? nullptr : overlayText.c_str(), + scaleMin, + scaleMax, + ImVec2(graphSize[0].cast(), graphSize[1].cast()), + stride); +} + +void PyWindow::imguiScaleAllSizes(float scale) { + bindImGui(); + ImGui::GetStyle().ScaleAllSizes(scale); +} + +void PyWindow::imguiAddBackgroundText(const std::string& text, + py::tuple pos, + py::tuple color) { + bindImGui(); + validateTupleSize(pos, 2, "pos"); + validateTupleSize(color, 4, "color"); + ImGui::GetBackgroundDrawList()->AddText( + ImVec2(pos[0].cast(), pos[1].cast()), + ImColor(color[0].cast(), color[1].cast(), color[2].cast(), color[3].cast()), + text.c_str()); +} + +void PyWindow::imguiSameLine(float offsetFromStartX, float spacing) { + bindImGui(); + ImGui::SameLine(offsetFromStartX, spacing); +} + +void PyWindow::imguiSeparator() { + bindImGui(); + ImGui::Separator(); +} + +void PyWindow::imguiSpacing() { + bindImGui(); + ImGui::Spacing(); +} + +void PyWindow::imguiIndent(float indentW) { + bindImGui(); + ImGui::Indent(indentW); +} + +void PyWindow::imguiUnindent(float indentW) { + bindImGui(); + ImGui::Unindent(indentW); +} + +bool PyWindow::imguiBeginChild(const std::string& id, + const py::object& size, + bool border, + int flags) { + bindImGui(); + ImVec2 vecSize = objectToVec2(size, "size"); + return ImGui::BeginChild(id.c_str(), vecSize, border, flags); +} + +void PyWindow::imguiEndChild() { + bindImGui(); + ImGui::EndChild(); +} + +void PyWindow::imguiTextWrapped(const std::string& text) { + bindImGui(); + ImGui::TextWrapped("%s", text.c_str()); +} + +void PyWindow::imguiTextColored(py::tuple color, + const std::string& text) { + bindImGui(); + ImVec4 col = tupleToVec4(color, "color"); + ImGui::TextColored(col, "%s", text.c_str()); +} + +void PyWindow::imguiBulletText(const std::string& text) { + bindImGui(); + ImGui::BulletText("%s", text.c_str()); +} + +std::tuple PyWindow::imguiInputText(const std::string& label, + const std::string& value, + size_t bufferLength, + int flags) { + bindImGui(); + size_t minimum = value.size() + 1; + size_t capacity = bufferLength ? std::max(bufferLength, minimum) : std::max(minimum, value.size() + 256); + if (capacity == 0) capacity = 1; + std::vector buffer(capacity, '\0'); + std::copy(value.begin(), value.end(), buffer.begin()); + bool edited = ImGui::InputText(label.c_str(), buffer.data(), buffer.size(), flags); + std::string result(buffer.data()); + return std::make_tuple(edited, std::move(result)); +} + +int PyWindow::imguiInputInt(const std::string& label, int value, int step, int stepFast, int flags) { + bindImGui(); + int v = value; + ImGui::InputInt(label.c_str(), &v, step, stepFast, flags); + return v; +} + +float PyWindow::imguiInputFloat(const std::string& label, + float value, + float step, + float stepFast, + const std::string& format, + int flags) { + bindImGui(); + float v = value; + ImGui::InputFloat(label.c_str(), &v, step, stepFast, format.c_str(), flags); + return v; +} + +std::tuple PyWindow::imguiColorEdit3(const std::string& label, + py::tuple color, + int flags) { + bindImGui(); + validateTupleSize(color, 3, "color"); + float col[3] = { + color[0].cast(), + color[1].cast(), + color[2].cast() + }; + bool changed = ImGui::ColorEdit3(label.c_str(), col, flags); + return std::make_tuple(changed, py::make_tuple(col[0], col[1], col[2])); +} + +std::tuple PyWindow::imguiColorEdit4(const std::string& label, + py::tuple color, + int flags) { + bindImGui(); + validateTupleSize(color, 4, "color"); + float col[4] = { + color[0].cast(), + color[1].cast(), + color[2].cast(), + color[3].cast() + }; + bool changed = ImGui::ColorEdit4(label.c_str(), col, flags); + return std::make_tuple(changed, py::make_tuple(col[0], col[1], col[2], col[3])); +} + +bool PyWindow::imguiBeginMainMenuBar() { + bindImGui(); + return ImGui::BeginMainMenuBar(); +} + +void PyWindow::imguiEndMainMenuBar() { + bindImGui(); + ImGui::EndMainMenuBar(); +} + +bool PyWindow::imguiBeginMenuBar() { + bindImGui(); + return ImGui::BeginMenuBar(); +} + +void PyWindow::imguiEndMenuBar() { + bindImGui(); + ImGui::EndMenuBar(); +} + +bool PyWindow::imguiBeginMenu(const std::string& label, bool enabled) { + bindImGui(); + return ImGui::BeginMenu(label.c_str(), enabled); +} + +void PyWindow::imguiEndMenu() { + bindImGui(); + ImGui::EndMenu(); +} + +bool PyWindow::imguiMenuItem(const std::string& label, + const py::object& shortcut, + bool selected, + bool enabled) { + bindImGui(); + std::string shortcutValue; + const char* shortcutPtr = nullptr; + if (!shortcut.is_none()) { + shortcutValue = shortcut.cast(); + shortcutPtr = shortcutValue.c_str(); + } + return ImGui::MenuItem(label.c_str(), shortcutPtr, selected, enabled); +} + +void PyWindow::imguiOpenPopup(const std::string& strId, int popupFlags) { + bindImGui(); + ImGui::OpenPopup(strId.c_str(), popupFlags); +} + +bool PyWindow::imguiBeginPopup(const std::string& strId, int flags) { + bindImGui(); + return ImGui::BeginPopup(strId.c_str(), flags); +} + +std::tuple PyWindow::imguiBeginPopupModal(const std::string& name, + const py::object& open, + int flags) { + bindImGui(); + bool openValue = open.is_none() ? true : open.cast(); + bool visible = ImGui::BeginPopupModal(name.c_str(), open.is_none() ? nullptr : &openValue, flags); + if (open.is_none()) { + return std::make_tuple(visible, py::none()); + } + return std::make_tuple(visible, py::cast(openValue)); +} + +void PyWindow::imguiEndPopup() { + bindImGui(); + ImGui::EndPopup(); +} + +void PyWindow::imguiCloseCurrentPopup() { + bindImGui(); + ImGui::CloseCurrentPopup(); +} + +void PyWindow::imguiPushStyleColor(int idx, py::tuple color) { + bindImGui(); + ImVec4 col = tupleToVec4(color, "color"); + ImGui::PushStyleColor(idx, col); +} + +void PyWindow::imguiPopStyleColor(int count) { + bindImGui(); + ImGui::PopStyleColor(count); +} + +void PyWindow::imguiPushStyleVarFloat(int idx, float value) { + bindImGui(); + ImGui::PushStyleVar(idx, value); +} + +void PyWindow::imguiPushStyleVarVec2(int idx, py::tuple value) { + bindImGui(); + ImVec2 vec = objectToVec2(value, "value"); + ImGui::PushStyleVar(idx, vec); +} + +void PyWindow::imguiPopStyleVar(int count) { + bindImGui(); + ImGui::PopStyleVar(count); +} + +float PyWindow::imguiGetFontGlobalScale() { + bindImGui(); + return ImGui::GetIO().FontGlobalScale; +} + +void PyWindow::imguiSetFontGlobalScale(float scale) { + bindImGui(); + ImGui::GetIO().FontGlobalScale = scale; +} + +py::tuple PyWindow::imguiGetStyleColorVec4(int idx) { + bindImGui(); + ImVec4 col = ImGui::GetStyle().Colors[idx]; + return vec4ToTuple(col); +} + +void PyWindow::imguiSetStyleColorVec4(int idx, py::tuple color) { + bindImGui(); + ImVec4 col = tupleToVec4(color, "color"); + ImGui::GetStyle().Colors[idx] = col; +} + +py::tuple PyWindow::mousePosition() { + ensureValid(); + double x = 0.0; + double y = 0.0; + glfwGetCursorPos(window_.wnd, &x, &y); + return py::make_tuple(x, y); +} + +bool PyWindow::isMouseButtonPressed(int button) { + ensureValid(); + return glfwGetMouseButton(window_.wnd, button) == GLFW_PRESS; +} + +bool PyWindow::imguiWantCaptureMouse() const { + ensureValid(); + if (!window_.imguiContext) { + return false; + } + ImGui::SetCurrentContext(window_.imguiContext); + return ImGui::GetIO().WantCaptureMouse; +} + +py::tuple PyWindow::consumeScrollDelta() { + ensureValid(); + double dx = window_.scrollDeltaX; + double dy = window_.scrollDeltaY; + window_.scrollDeltaX = 0.0; + window_.scrollDeltaY = 0.0; + return py::make_tuple(dx, dy); +} + +void PyWindow::ensureValid() const { + if (!window_.wnd) { + throw std::runtime_error("Window has been closed"); + } +} + +void PyWindow::moveFrom(PyWindow&& other) { + ctx_ = other.ctx_; + window_ = std::move(other.window_); + other.ctx_ = nullptr; + if (window_.wnd) { + TFWindowDetail::RegisterScrollContext(window_.wnd, &window_); + } +} + +ImGuiContext* PyWindow::bindImGui() { + ensureValid(); + EnsureImGuiFrame(window_); + if (!window_.imguiContext) { + throw std::runtime_error("ImGui context is not initialized for this window"); + } + ImGui::SetCurrentContext(window_.imguiContext); + return window_.imguiContext; +} + +void PyWindow::validateTupleSize(const py::tuple& tpl, size_t expected, const char* name) { + if (tpl.size() != expected) { + throw std::invalid_argument(std::string("Expected tuple of size ") + std::to_string(expected) + " for " + name); + } +} + +ImVec2 PyWindow::objectToVec2(const py::object& obj, const char* name) { + if (obj.is_none()) { + return ImVec2(0.0f, 0.0f); + } + py::tuple tpl = obj.cast(); + validateTupleSize(tpl, 2, name); + return ImVec2(tpl[0].cast(), tpl[1].cast()); +} + +ImVec4 PyWindow::tupleToVec4(const py::tuple& tpl, const char* name) { + validateTupleSize(tpl, 4, name); + return ImVec4(tpl[0].cast(), + tpl[1].cast(), + tpl[2].cast(), + tpl[3].cast()); +} + +py::tuple PyWindow::vec4ToTuple(const ImVec4& vec) { + return py::make_tuple(vec.x, vec.y, vec.z, vec.w); +} + +PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, + const std::string& source, + const std::string& entry, + uint32_t roCount, + uint32_t rwCount, + uint32_t pushConstantSize) { + return PyComputeProgram(createComputeProgramFromSlang( + moduleName, source, entry, roCount, rwCount, pushConstantSize)); +} + +} // namespace TensorFrost diff --git a/TensorFrost/src/Definitions/VulkanInterface.h b/TensorFrost/src/Definitions/VulkanInterface.h new file mode 100644 index 00000000..8b462151 --- /dev/null +++ b/TensorFrost/src/Definitions/VulkanInterface.h @@ -0,0 +1,217 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "Backend/Vulkan.h" +#include "Backend/Window.h" + +struct ImVec2; +struct ImVec4; + +namespace TensorFrost { + +class PyBuffer { +public: + PyBuffer(size_t count, size_t dtypeSize, bool readOnly); + ~PyBuffer(); + + PyBuffer(const PyBuffer&) = delete; + PyBuffer& operator=(const PyBuffer&) = delete; + + PyBuffer(PyBuffer&& other) noexcept; + PyBuffer& operator=(PyBuffer&& other) noexcept; + + bool valid() const; + size_t byteSize() const; + size_t elementCapacity() const; + bool isReadOnly() const; + + void release(); + + void setData(const pybind11::array& array, size_t offset); + pybind11::array getData(const pybind11::object& dtypeArg, + const pybind11::object& countArg, + size_t offset) const; + + Buffer& raw(); + const Buffer& raw() const; + +private: + void ensureValid() const; + void moveFrom(PyBuffer&& other); + pybind11::dtype resolveDtype(const pybind11::object& dtypeArg) const; + size_t resolveCount(const pybind11::object& countArg, size_t itemsize, size_t available) const; + + VulkanContext* ctx_{}; + Buffer buffer_{}; + bool readOnly_{}; + size_t dtypeSizeHint_{}; + size_t lastCount_{}; + pybind11::object lastDtype_; +}; + +class PyComputeProgram { +public: + explicit PyComputeProgram(ComputeProgram&& prog); + ~PyComputeProgram(); + + PyComputeProgram(const PyComputeProgram&) = delete; + PyComputeProgram& operator=(const PyComputeProgram&) = delete; + + PyComputeProgram(PyComputeProgram&& other) noexcept; + PyComputeProgram& operator=(PyComputeProgram&& other) noexcept; + + void run(const pybind11::iterable& readonlyBuffers, + const pybind11::iterable& readwriteBuffers, + uint32_t groupCount, + const pybind11::object& pushConstants); + + void release(); + + uint32_t readonlyCount() const; + uint32_t readwriteCount() const; + uint32_t pushConstantSize() const; + +private: + void ensureValid() const; + static void collectBuffers(const pybind11::iterable& items, + std::vector& out, + const char* label); + void moveFrom(PyComputeProgram&& other); + + VulkanContext* ctx_{}; + ComputeProgram program_{}; +}; + +class PyWindow { +public: + PyWindow(int width, int height, const std::string& title); + ~PyWindow(); + + PyWindow(const PyWindow&) = delete; + PyWindow& operator=(const PyWindow&) = delete; + + PyWindow(PyWindow&& other) noexcept; + PyWindow& operator=(PyWindow&& other) noexcept; + + bool isOpen() const; + void drawBuffer(const PyBuffer& buffer, uint32_t width, uint32_t height, size_t offset); + void present(); + pybind11::tuple size(); + int format() const; + void close(); + + pybind11::tuple imguiBegin(const std::string& name, + const std::optional& open, + int flags); + void imguiEnd(); + void imguiText(const std::string& text); + bool imguiButton(const std::string& label); + bool imguiCheckbox(const std::string& label, bool value); + int imguiSliderInt(const std::string& label, int value, int min, int max); + float imguiSliderFloat(const std::string& label, float value, float min, float max); + void imguiPlotLines(const std::string& label, + pybind11::array_t values, + int valuesOffset, + const std::string& overlayText, + float scaleMin, + float scaleMax, + pybind11::tuple graphSize, + int stride); + void imguiScaleAllSizes(float scale); + void imguiAddBackgroundText(const std::string& text, + pybind11::tuple pos, + pybind11::tuple color); + void imguiSameLine(float offsetFromStartX, float spacing); + void imguiSeparator(); + void imguiSpacing(); + void imguiIndent(float indentW); + void imguiUnindent(float indentW); + bool imguiBeginChild(const std::string& id, + const pybind11::object& size, + bool border, + int flags); + void imguiEndChild(); + void imguiTextWrapped(const std::string& text); + void imguiTextColored(pybind11::tuple color, + const std::string& text); + void imguiBulletText(const std::string& text); + std::tuple imguiInputText(const std::string& label, + const std::string& value, + size_t bufferLength, + int flags); + int imguiInputInt(const std::string& label, int value, int step, int stepFast, int flags); + float imguiInputFloat(const std::string& label, + float value, + float step, + float stepFast, + const std::string& format, + int flags); + std::tuple imguiColorEdit3(const std::string& label, + pybind11::tuple color, + int flags); + std::tuple imguiColorEdit4(const std::string& label, + pybind11::tuple color, + int flags); + bool imguiBeginMainMenuBar(); + void imguiEndMainMenuBar(); + bool imguiBeginMenuBar(); + void imguiEndMenuBar(); + bool imguiBeginMenu(const std::string& label, bool enabled); + void imguiEndMenu(); + bool imguiMenuItem(const std::string& label, + const pybind11::object& shortcut, + bool selected, + bool enabled); + void imguiOpenPopup(const std::string& strId, int popupFlags); + bool imguiBeginPopup(const std::string& strId, int flags); + std::tuple imguiBeginPopupModal(const std::string& name, + const pybind11::object& open, + int flags); + void imguiEndPopup(); + void imguiCloseCurrentPopup(); + void imguiPushStyleColor(int idx, pybind11::tuple color); + void imguiPopStyleColor(int count); + void imguiPushStyleVarFloat(int idx, float value); + void imguiPushStyleVarVec2(int idx, pybind11::tuple value); + void imguiPopStyleVar(int count); + float imguiGetFontGlobalScale(); + void imguiSetFontGlobalScale(float scale); + pybind11::tuple imguiGetStyleColorVec4(int idx); + void imguiSetStyleColorVec4(int idx, pybind11::tuple color); + + pybind11::tuple mousePosition(); + bool isMouseButtonPressed(int button); + bool imguiWantCaptureMouse() const; + pybind11::tuple consumeScrollDelta(); + +private: + void ensureValid() const; + void moveFrom(PyWindow&& other); + ImGuiContext* bindImGui(); + static void validateTupleSize(const pybind11::tuple& tpl, size_t expected, const char* name); + static ImVec2 objectToVec2(const pybind11::object& obj, const char* name); + static ImVec4 tupleToVec4(const pybind11::tuple& tpl, const char* name); + static pybind11::tuple vec4ToTuple(const ImVec4& vec); + + VulkanContext* ctx_{}; + WindowContext window_{}; +}; + +PyComputeProgram MakeComputeProgramFromSlang(const std::string& moduleName, + const std::string& source, + const std::string& entry, + uint32_t roCount, + uint32_t rwCount, + uint32_t pushConstantSize = 0); + +} // namespace TensorFrost diff --git a/examples/Slang/mandelbrot.py b/examples/Slang/mandelbrot.py new file mode 100644 index 00000000..f8021ca5 --- /dev/null +++ b/examples/Slang/mandelbrot.py @@ -0,0 +1,155 @@ +from pathlib import Path +import math +import time + +import numpy as np +import TensorFrost as tf + + +def load_shader() -> str: + with open(Path(__file__).with_name("mandelbrot.slang"), "r", encoding="utf-8") as handle: + return handle.read() + + +def main() -> None: + width, height = 1024, 768 + win = tf.createWindow(width, height, "Mandelbrot (ImGui)") + win.imgui_scale_all_sizes(2.0) + win.imgui_set_font_global_scale(2.0) + + fmt = int(win.format) + is_bgra_default = fmt in (44, 50) # VK_FORMAT_B8G8R8A8_UNORM / _SRGB + + pixel_capacity = max(1, width * height) + pixel_buffer = tf.createBuffer(pixel_capacity, 4, False) + + shader_source = load_shader() + program = tf.createComputeProgramFromSlang( + "mandelbrot", + shader_source, + "csMain", + ro_count=0, + rw_count=1, + push_constant_size=32, + ) + local_size = 64 + + center = [-0.5, 0.0] + scale = 3.0 + log_scale = math.log10(scale) + pending_scroll = 0.0 + manual_iterations = 500 + auto_iterations = True + swap_rb = is_bgra_default + plot_history = np.zeros(120, dtype=np.float32) + history_index = 0 + + params = np.zeros(8, dtype=np.float32) + prev_mouse_pos = win.mouse_position() + dragging = False + prev_time = time.perf_counter() + fps = 0.0 + + def ensure_pixel_buffer(cur_width: int, cur_height: int) -> None: + nonlocal pixel_buffer, pixel_capacity + required = max(1, cur_width * cur_height) + if required != pixel_capacity: + pixel_buffer = tf.createBuffer(required, 4, False) + pixel_capacity = required + + while win.isOpen(): + now = time.perf_counter() + dt = max(now - prev_time, 1e-6) + prev_time = now + fps = fps * 0.9 + (1.0 / dt) * 0.1 + + width, height = win.size + width = max(1, int(width)) + height = max(1, int(height)) + thread_count = max(1, width * height) + group_count = max((thread_count + local_size - 1) // local_size, 1) + + ensure_pixel_buffer(width, height) + + if pending_scroll: + scroll_adjust = pending_scroll * 0.12 + log_scale = max(min(log_scale - scroll_adjust, 0.5), -4.5) + scale = pow(10.0, log_scale) + pending_scroll = 0.0 + + aspect = height / float(width) + + visible, _ = win.imgui_begin("Mandelbrot Controls", open=True) + if visible: + win.imgui_text(f"Resolution: {width} × {height}") + win.imgui_text(f"Format: {fmt}") + win.imgui_text(f"FPS: {fps:5.1f} | {dt * 1000.0:.2f} ms") + log_scale = win.imgui_slider_float("log₁₀ scale", log_scale, -4.5, 0.5) + scale = pow(10.0, log_scale) + center[0] = win.imgui_slider_float("Center X", center[0], -2.5, 1.5) + center[1] = win.imgui_slider_float("Center Y", center[1], -1.5, 1.5) + auto_iterations = win.imgui_checkbox("Auto iterations", auto_iterations) + if auto_iterations: + auto_value = max(64, int(200 + (-math.log10(scale)) * 120)) + manual_iterations = auto_value + win.imgui_text(f"Max iterations (auto): {auto_value}") + else: + manual_iterations = win.imgui_slider_int("Max iterations", manual_iterations, 64, 5000) + swap_rb = win.imgui_checkbox("BGRA swap", swap_rb) + if win.imgui_button("Reset view"): + center[:] = (-0.5, 0.0) + scale = 3.0 + log_scale = math.log10(scale) + manual_iterations = 500 + auto_iterations = True + pending_scroll = 0.0 + + plot_history[history_index % plot_history.size] = float(manual_iterations) + history_index += 1 + win.imgui_plot_lines( + "Iteration history", + plot_history, + values_offset=history_index % plot_history.size, + graph_size=(0.0, 60.0), + ) + win.imgui_end() + + xspan = scale + yspan = xspan * aspect + dx = xspan / width + dy = yspan / height + + mouse_pos = win.mouse_position() + want_capture_mouse = win.imgui_want_capture_mouse() + mouse_down = win.is_mouse_button_pressed(0) + dragging_now = mouse_down and not want_capture_mouse + if dragging and dragging_now: + delta_x = mouse_pos[0] - prev_mouse_pos[0] + delta_y = mouse_pos[1] - prev_mouse_pos[1] + center[0] -= delta_x * dx + center[1] -= delta_y * dy + dragging = dragging_now + prev_mouse_pos = mouse_pos + + xmin = center[0] - xspan * 0.5 + ymin = center[1] - yspan * 0.5 + params[0] = float(width) + params[1] = float(height) + params[2] = xmin + params[3] = ymin + params[4] = dx + params[5] = dy + params[6] = float(manual_iterations) + params[7] = 1.0 if swap_rb else 0.0 + + program.run([], [pixel_buffer], group_count, params) + + win.drawBuffer(pixel_buffer, width, height) + + _, scroll_dy = win.consume_scroll_delta() + if not want_capture_mouse: + pending_scroll += scroll_dy + + +if __name__ == "__main__": + main() diff --git a/examples/Slang/mandelbrot.slang b/examples/Slang/mandelbrot.slang new file mode 100644 index 00000000..bcfca9ec --- /dev/null +++ b/examples/Slang/mandelbrot.slang @@ -0,0 +1,66 @@ +struct MandelbrotParams { + float width; + float height; + float xmin; + float ymin; + float dx; + float dy; + float maxIter; + float swapRB; +}; + +[[vk::push_constant]] +MandelbrotParams gParams; + +[[vk::binding(0,0)]] RWStructuredBuffer Pixels : register(u0, space0); + +static float3 palette(float t) { + return sin(6.3 * float3(0.0, 0.33, 0.67) * t); +} + +[numthreads(64,1,1)] +void csMain(uint3 tid : SV_DispatchThreadID) +{ + uint idx = tid.x; + int W = (int)(gParams.width + 0.5); + int H = (int)(gParams.height + 0.5); + uint N = (uint)(W * H); + if (idx >= N) return; + + int x = int(idx % (uint)W); + int y = int(idx / (uint)W); + + float xmin = gParams.xmin; + float ymin = gParams.ymin; + float dx = gParams.dx; + float dy = gParams.dy; + int maxIter = (int)(gParams.maxIter + 0.5); + bool isBGRA = (gParams.swapRB > 0.5); + + float cx = xmin + float(x) * dx; + float cy = ymin + float(y) * dy; + + float zx = 0.0, zy = 0.0; + int i = 0; + [loop] + for (; i < maxIter; ++i) { + float zx2 = zx*zx - zy*zy + cx; + zy = 2.0*zx*zy + cy; + zx = zx2; + if (zx*zx + zy*zy > 4.0) break; + } + + float t = (float(i) - log2(log(sqrt(zx*zx + zy*zy))) + 4.0); + t = clamp(t / float(maxIter), 0.0, 1.0); + + float3 rgb = palette(t); + float4 c = float4(rgb, 1.0); + + uint r = (uint)round(saturate(c.r) * 255.0); + uint g = (uint)round(saturate(c.g) * 255.0); + uint b = (uint)round(saturate(c.b) * 255.0); + uint a = (uint)round(saturate(c.a) * 255.0); + uint packed = isBGRA ? (b | (g<<8) | (r<<16) | (a<<24)) + : (r | (g<<8) | (b<<16) | (a<<24)); + Pixels[idx] = packed; +} diff --git a/examples/counting_sort/__main__.py b/examples/counting_sort/__main__.py new file mode 100644 index 00000000..2a77bd09 --- /dev/null +++ b/examples/counting_sort/__main__.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import argparse +import math +import time +from collections import deque +from pathlib import Path + +import numpy as np +import TensorFrost as tf + +try: + from .sort import CountingSort +except ImportError: + import sys + + _CURRENT_DIR = Path(__file__).resolve().parent + if str(_CURRENT_DIR) not in sys.path: + sys.path.insert(0, str(_CURRENT_DIR)) + + from sort import CountingSort + + +_STAGE_NAMES = ( + "histogram", + "block_sum", + "block_prefix_stage1", + "block_prefix_stage2", + "scatter", +) + + +def _select_backend() -> None: + if hasattr(tf, "initialize"): + backend = getattr(tf, "vulkan", None) + if backend is None: + raise RuntimeError("TensorFrost Vulkan backend is unavailable on this build") + tf.initialize(backend) + + +def _format_rate(value: float, unit: str) -> str: + if value <= 0.0 or not math.isfinite(value): + return f"0 {unit}" + + prefixes = ("", "K", "M", "G", "T", "P") + magnitude = 0 + while value >= 1000.0 and magnitude < len(prefixes) - 1: + value /= 1000.0 + magnitude += 1 + + return f"{value:6.2f} {prefixes[magnitude]}{unit}" + + +def _format_stage_name(name: str) -> str: + return name.replace("_", " ").title() + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Interactive counting sort demo running continuously on the Vulkan backend.", + ) + parser.add_argument("--size", type=int, default=1 << 22, help="Number of keys to sort") + parser.add_argument( + "--max-value", + type=int, + default=1 << 20, + help="Exclusive upper bound for the generated key values", + ) + parser.add_argument("--window-width", type=int, default=960, help="Window width in pixels") + parser.add_argument("--window-height", type=int, default=540, help="Window height in pixels") + parser.add_argument("--font-scale", type=float, default=1.6, help="Global ImGui font scale") + parser.add_argument("--history", type=int, default=240, help="Number of samples stored for metrics history") + parser.add_argument("--seed", type=int, default=1337, help="Seed for the random data generator") + args = parser.parse_args() + + _select_backend() + + count = max(0, int(args.size)) + max_value = max(1, int(args.max_value)) + window_width = max(320, int(args.window_width)) + window_height = max(240, int(args.window_height)) + history_length = max(1, int(args.history)) + font_scale = max(0.1, float(args.font_scale)) if args.font_scale > 0 else 0.0 + + rng = np.random.default_rng(int(args.seed)) + keys = ( + rng.integers(0, max_value, size=count, dtype=np.uint32) + if count + else np.empty(0, dtype=np.uint32) + ) + values = ( + rng.integers(0, 1 << 31, size=count, dtype=np.uint32) + if count + else None + ) + + frame_times = deque(maxlen=history_length) + sort_times = deque(maxlen=history_length) + stage_totals_overall = {name: 0.0 for name in _STAGE_NAMES} + + validated = False + validation_ok = False + validation_message = "" + + sort_count = 0 + total_kernel_time = 0.0 + + sorter = CountingSort(max_value=max_value) + + window_title = f"TensorFrost Counting Sort ({count:,} elements)" + window = tf.createWindow(window_width, window_height, window_title) + + if font_scale > 0.0 and window is not None: + window.imgui_scale_all_sizes(font_scale) + window.imgui_set_font_global_scale(font_scale) + + start_time = time.perf_counter() + last_frame_time = start_time + + while window is not None and window.isOpen(): + now = time.perf_counter() + dt = now - last_frame_time + last_frame_time = now + + if dt > 0.0: + frame_times.append(dt) + frame_time_total = sum(frame_times) + fps = (len(frame_times) / frame_time_total) if frame_times and frame_time_total > 0.0 else 0.0 + + do_validate = not validated + return_arrays = do_validate + _keys_out, _vals_out = sorter.sort( + keys, + values, + collect_stage_timings=True, + validate=do_validate, + return_arrays=return_arrays, + ) + + stage_timings = sorter.last_stage_timings or {} + total_pass_time = float(stage_timings.get("total_pass", 0.0)) + stage_sum_time = float(sum(v for k, v in stage_timings.items() if k != "total_pass")) + kernel_time = total_pass_time if total_pass_time > 0.0 else stage_sum_time + + if sort_times.maxlen == len(sort_times): + sort_times.popleft() + sort_times.append(kernel_time) + + sort_count += 1 + total_kernel_time += kernel_time + for name in _STAGE_NAMES: + stage_totals_overall[name] += stage_timings.get(name, 0.0) + + if do_validate: + errors = int(getattr(sorter, "last_validation_errors", 0) or 0) + gpu_validation_ok = (errors == 0) + + cpu_validation_ok = True + cpu_messages = [] + if return_arrays and count: + reference_keys = np.sort(keys, kind="stable") + if not np.array_equal(_keys_out, reference_keys): + cpu_validation_ok = False + cpu_messages.append("key mismatch") + + validation_ok = gpu_validation_ok and cpu_validation_ok + if validation_ok: + validation_message = "GPU + NumPy validation passed." + else: + failure_reasons = [] + if not gpu_validation_ok: + failure_reasons.append(f"GPU reported {errors} out-of-order pairs") + if not cpu_validation_ok: + reason = ", ".join(cpu_messages) if cpu_messages else "NumPy comparison failed" + failure_reasons.append(reason) + validation_message = "Validation failed: " + "; ".join(failure_reasons) + + validated = True + + window_avg_sort = (sum(sort_times) / len(sort_times)) if sort_times else 0.0 + last_sort_ms = kernel_time * 1000.0 + avg_sort_ms = window_avg_sort * 1000.0 + + last_sort_rate = (1.0 / kernel_time) if kernel_time > 0.0 else 0.0 + window_sort_rate = (1.0 / window_avg_sort) if window_avg_sort > 0.0 else 0.0 + overall_avg_sort = (total_kernel_time / sort_count) if sort_count else 0.0 + overall_sort_rate = (1.0 / overall_avg_sort) if overall_avg_sort > 0.0 else 0.0 + + last_elements_per_sec = (count / kernel_time) if count and kernel_time > 0.0 else 0.0 + window_elements_per_sec = (count / window_avg_sort) if count and window_avg_sort > 0.0 else 0.0 + + history_data = np.array(sort_times, dtype=np.float32) * 1000.0 + + visible, _ = window.imgui_begin("Counting Sort Performance", open=None, flags=0) + if visible: + window.imgui_text(f"Elements: {count:,}") + window.imgui_text(f"Max value: {max_value:,}") + window.imgui_separator() + window.imgui_text(f"Frame time: {dt * 1000.0:7.3f} ms") + window.imgui_text(f"Average FPS: {fps:6.1f}") + window.imgui_separator() + + if sort_count: + window.imgui_text(f"Last kernel time: {last_sort_ms:7.3f} ms") + if total_pass_time > 0.0: + window.imgui_text(f"Total pass time (last): {total_pass_time * 1000.0:7.3f} ms") + if window_avg_sort > 0.0: + window.imgui_text(f"Window avg kernel: {avg_sort_ms:7.3f} ms") + if overall_avg_sort > 0.0: + window.imgui_text(f"Overall avg kernel: {overall_avg_sort * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text(f"Sorts/sec (last): {last_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (window): {window_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (overall): {overall_sort_rate:8.2f}") + + if count: + window.imgui_spacing() + window.imgui_text( + f"Elements/sec (last): {_format_rate(last_elements_per_sec, ' elements/s')}" + ) + window.imgui_text( + f"Elements/sec (window): {_format_rate(window_elements_per_sec, ' elements/s')}" + ) + + window.imgui_spacing() + window.imgui_text(f"Total sorts: {sort_count:,}") + else: + if count: + window.imgui_text("Waiting for first GPU sort...") + else: + window.imgui_text("No elements to sort.") + + if validated: + color = (0.20, 0.80, 0.40, 1.0) if validation_ok else (0.92, 0.35, 0.32, 1.0) + window.imgui_text_colored(color, validation_message) + elif count: + window.imgui_text("Validating against CPU reference...") + + if history_data.size: + max_plot = float(np.max(history_data)) if history_data.size else 1.0 + max_plot = max(1.0, max_plot * 1.1) + window.imgui_plot_lines( + "Kernel time history (ms)", + history_data, + values_offset=0, + overlay_text=f"{history_data.size} sample history", + scale_min=0.0, + scale_max=max_plot, + graph_size=(0.0, 140.0), + stride=4, + ) + + if stage_timings: + window.imgui_separator() + window.imgui_text("Kernel stages (last run):") + for name, duration in sorted(stage_timings.items(), key=lambda item: item[1], reverse=True): + if name == "total_pass": + continue + window.imgui_text(f"{_format_stage_name(name)}: {duration * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text("Kernel stages (overall avg):") + for name in sorted(stage_totals_overall.keys()): + avg_duration = (stage_totals_overall[name] / sort_count) if sort_count else 0.0 + window.imgui_text(f"{_format_stage_name(name)}: {avg_duration * 1000.0:7.3f} ms") + + window.imgui_end() + window.present() + + +if __name__ == "__main__": + main() diff --git a/examples/counting_sort/shaders/block_prefix_stage2.slang b/examples/counting_sort/shaders/block_prefix_stage2.slang new file mode 100644 index 00000000..ca4d7275 --- /dev/null +++ b/examples/counting_sort/shaders/block_prefix_stage2.slang @@ -0,0 +1,24 @@ +struct BlockPrefixStage2Params { + uint groupCount; +}; + +[[vk::push_constant]] +BlockPrefixStage2Params gParams; + +[[vk::binding(0,0)]] StructuredBuffer BlockGroupTotals : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer BlockGroupPrefix : register(u1, space0); + +[numthreads(64, 1, 1)] +void csBlockPrefixStage2(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + if (dispatchThreadID.x != 0) + return; + + uint groupCount = gParams.groupCount; + uint prefix = 0; + for (uint group = 0; group < groupCount; ++group) + { + BlockGroupPrefix[group] = prefix; + prefix += BlockGroupTotals[group]; + } +} diff --git a/examples/counting_sort/shaders/block_sum.slang b/examples/counting_sort/shaders/block_sum.slang new file mode 100644 index 00000000..c3312fc2 --- /dev/null +++ b/examples/counting_sort/shaders/block_sum.slang @@ -0,0 +1,37 @@ +struct SegmentScanParams { + uint segmentCount; + uint segmentSpan; + uint inputLimit; +}; + +[[vk::push_constant]] +SegmentScanParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer InputData : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer PrefixOutput : register(u1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer SegmentTotals : register(u2, space0); + +[numthreads(64, 1, 1)] +void csSegmentScan(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint segment = dispatchThreadID.x; + if (segment >= gParams.segmentCount) + return; + + uint start = segment * gParams.segmentSpan; + if (start >= gParams.inputLimit) + { + SegmentTotals[segment] = 0; + return; + } + + uint endIndex = min(start + gParams.segmentSpan, gParams.inputLimit); + uint prefix = 0; + for (uint idx = start; idx < endIndex; ++idx) + { + PrefixOutput[idx] = prefix; + prefix += InputData[idx]; + } + + SegmentTotals[segment] = prefix; +} diff --git a/examples/counting_sort/shaders/histogram_rank.slang b/examples/counting_sort/shaders/histogram_rank.slang new file mode 100644 index 00000000..a20f11cf --- /dev/null +++ b/examples/counting_sort/shaders/histogram_rank.slang @@ -0,0 +1,29 @@ +static const uint GROUP_SIZE = CS_GROUP_SIZE; + +struct HistogramParams { + uint elementCount; + uint valueCount; +}; + +[[vk::push_constant]] +HistogramParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer KeysIn : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Histogram : register(u1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer LocalRank : register(u2, space0); + +[numthreads(GROUP_SIZE, 1, 1)] +void csHistogramAndRank(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + if (index >= gParams.elementCount) + return; + + uint value = KeysIn[index]; + if (value >= gParams.valueCount) + return; + + uint rank = 0; + InterlockedAdd(Histogram[value], 1u, rank); + LocalRank[index] = rank; +} diff --git a/examples/counting_sort/shaders/scatter.slang b/examples/counting_sort/shaders/scatter.slang new file mode 100644 index 00000000..ec1bc039 --- /dev/null +++ b/examples/counting_sort/shaders/scatter.slang @@ -0,0 +1,47 @@ +static const uint GROUP_SIZE = CS_GROUP_SIZE; + +struct ScatterParams { + uint elementCount; + uint valueCount; + uint blockSpan; + uint blocksPerGroup; + uint hasValues; +}; + +[[vk::push_constant]] +ScatterParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer KeysIn : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer ValuesIn : register(t1, space0); +[[vk::binding(2,0)]] StructuredBuffer LocalRank : register(t2, space0); +[[vk::binding(3,0)]] StructuredBuffer LocalPrefix : register(t3, space0); +[[vk::binding(4,0)]] StructuredBuffer BlockPrefixStage1 : register(t4, space0); +[[vk::binding(5,0)]] StructuredBuffer BlockGroupPrefix : register(t5, space0); +[[vk::binding(6,0)]] RWStructuredBuffer KeysOut : register(u6, space0); +[[vk::binding(7,0)]] RWStructuredBuffer ValuesOut : register(u7, space0); + +[numthreads(GROUP_SIZE, 1, 1)] +void csScatter(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + uint elementCount = gParams.elementCount; + if (index >= elementCount) + return; + + uint value = KeysIn[index]; + if (value >= gParams.valueCount) + return; + + uint rank = LocalRank[index]; + uint block = value / gParams.blockSpan; + uint group = block / max(gParams.blocksPerGroup, 1u); + uint localOffset = LocalPrefix[value]; + uint blockOffset = BlockPrefixStage1[block] + BlockGroupPrefix[group]; + uint destination = blockOffset + localOffset + rank; + + KeysOut[destination] = value; + if (gParams.hasValues != 0) + { + ValuesOut[destination] = ValuesIn[index]; + } +} diff --git a/examples/counting_sort/shaders/validate_sorted.slang b/examples/counting_sort/shaders/validate_sorted.slang new file mode 100644 index 00000000..524e5bf7 --- /dev/null +++ b/examples/counting_sort/shaders/validate_sorted.slang @@ -0,0 +1,27 @@ +static const uint GROUP_SIZE = CS_GROUP_SIZE; + +struct ValidateParams { + uint elementCount; +}; + +[[vk::push_constant]] +ValidateParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer Keys : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Errors : register(u1, space0); + +[numthreads(GROUP_SIZE, 1, 1)] +void csValidate(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + uint n = gParams.elementCount; + if (n < 2 || index >= n - 1) + return; + + uint a = Keys[index]; + uint b = Keys[index + 1u]; + if (a > b) + { + InterlockedAdd(Errors[0], 1u); + } +} diff --git a/examples/counting_sort/sort.py b/examples/counting_sort/sort.py new file mode 100644 index 00000000..1b8c8b0c --- /dev/null +++ b/examples/counting_sort/sort.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional, Tuple + +import numpy as np + +import TensorFrost as tf + +__all__ = ["CountingSort", "counting_sort"] + + +def _dispatch_groups(work_items: int, threads_per_group: int) -> int: + if work_items <= 0: + return 0 + return (work_items + threads_per_group - 1) // threads_per_group + + +def _prepare_keys(keys: np.ndarray) -> np.ndarray: + array = np.asarray(keys) + if array.ndim != 1: + raise ValueError("counting_sort expects a 1D array of keys") + if array.dtype != np.uint32: + array = array.astype(np.uint32, copy=False) + return array + + +def _prepare_values(values: np.ndarray) -> Tuple[np.ndarray, np.dtype]: + array = np.asarray(values) + if array.ndim != 1: + raise ValueError("counting_sort expects a 1D array of values when provided") + dtype = array.dtype + if dtype not in (np.uint32, np.int32, np.float32): + raise TypeError("values must have dtype uint32, int32, or float32") + return array, dtype + + +_SHADER_DIR = Path(__file__).resolve().parent / "shaders" + + +def _load_shader_source(filename: str) -> str: + return (_SHADER_DIR / filename).read_text(encoding="utf-8") + + +@dataclass(frozen=True) +class _SorterKey: + max_value: int + block_span: int + group_size: int + blocks_per_group: int + + +class CountingSort: + """GPU counting sort built on top of TensorFrost compute and Slang shaders.""" + + def __init__( + self, + *, + max_value: int, + block_span: int = 256, + group_size: int = 128, + blocks_per_group: int = 256, + ) -> None: + if max_value <= 0: + raise ValueError("max_value must be positive") + if block_span <= 0: + raise ValueError("block_span must be positive") + if group_size <= 0 or group_size > 1024: + raise ValueError("group_size must be within (0, 1024]") + if blocks_per_group <= 0: + raise ValueError("blocks_per_group must be positive") + + self.value_count = int(max_value) + self.block_span = int(block_span) + self.group_size = int(group_size) + self.blocks_per_group = int(blocks_per_group) + self.block_count = max((self.value_count + self.block_span - 1) // self.block_span, 1) + self.block_group_count = max((self.block_count + self.blocks_per_group - 1) // self.blocks_per_group, 1) + + self.last_stage_timings: Optional[Dict[str, float]] = None + self.last_validation_errors: Optional[int] = None + + def inject_defines(filename: str) -> str: + defines = [ + f"#define CS_GROUP_SIZE {self.group_size}u", + f"#define CS_BLOCK_SPAN {self.block_span}u", + ] + return "\n".join(defines) + "\n" + _load_shader_source(filename) + + self._dummy_values_buffer = tf.createBuffer(1, 4, False) + + self._histogram_program = tf.createComputeProgramFromSlang( + "count_histogram_rank", + inject_defines("histogram_rank.slang"), + "csHistogramAndRank", + ro_count=1, + rw_count=2, + push_constant_size=8, + ) + self._segment_scan_program = tf.createComputeProgramFromSlang( + "count_segment_scan", + inject_defines("block_sum.slang"), + "csSegmentScan", + ro_count=1, + rw_count=2, + push_constant_size=12, + ) + self._block_prefix_stage2_program = tf.createComputeProgramFromSlang( + "count_block_prefix_stage2", + inject_defines("block_prefix_stage2.slang"), + "csBlockPrefixStage2", + ro_count=1, + rw_count=1, + push_constant_size=4, + ) + self._scatter_program = tf.createComputeProgramFromSlang( + "count_scatter", + inject_defines("scatter.slang"), + "csScatter", + ro_count=6, + rw_count=2, + push_constant_size=20, + ) + self._validate_program = tf.createComputeProgramFromSlang( + "count_validate", + inject_defines("validate_sorted.slang"), + "csValidate", + ro_count=1, + rw_count=1, + push_constant_size=4, + ) + + def sort( + self, + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + collect_stage_timings: bool = False, + validate: bool = False, + return_arrays: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + keys_array = _prepare_keys(keys) + element_count = int(keys_array.shape[0]) + + if values is not None: + values_array, values_dtype = _prepare_values(values) + if values_array.shape[0] != element_count: + raise ValueError("values must have the same length as keys") + else: + values_array = None + values_dtype = None + + if element_count == 0: + self.last_stage_timings = {} if collect_stage_timings else None + if validate: + self.last_validation_errors = 0 + if values_array is None: + return keys_array.copy(), None + return keys_array.copy(), values_array.copy() + + if np.max(keys_array) >= self.value_count: + raise ValueError(f"keys must be within [0, {self.value_count}) for this sorter") + + if collect_stage_timings: + stage_totals: Dict[str, float] = { + "histogram": 0.0, + "block_sum": 0.0, + "block_prefix_stage1": 0.0, + "block_prefix_stage2": 0.0, + "scatter": 0.0, + } + else: + stage_totals = {} + + key_in_buffer = tf.createBuffer(max(element_count, 1), 4, False) + key_in_buffer.setData(keys_array) + key_out_buffer = tf.createBuffer(max(element_count, 1), 4, False) + + if values_array is not None: + val_in_buffer = tf.createBuffer(max(element_count, 1), 4, False) + val_in_buffer.setData(values_array) + val_out_buffer = tf.createBuffer(max(element_count, 1), 4, False) + else: + dummy = self._dummy_values_buffer + val_in_buffer = dummy + val_out_buffer = dummy + + histogram_buffer = tf.createBuffer(max(self.value_count, 1), 4, False) + z_hist = np.zeros(self.value_count, dtype=np.uint32) + histogram_buffer.setData(z_hist) + + local_rank_buffer = tf.createBuffer(max(element_count, 1), 4, False) + local_prefix_buffer = tf.createBuffer(max(self.value_count, 1), 4, False) + block_totals_buffer = tf.createBuffer(max(self.block_count, 1), 4, False) + block_prefix_stage1_buffer = tf.createBuffer(max(self.block_count, 1), 4, False) + block_group_totals_buffer = tf.createBuffer(max(self.block_group_count, 1), 4, False) + block_group_prefix_buffer = tf.createBuffer(max(self.block_group_count, 1), 4, False) + + hist_params = np.zeros(2, dtype=np.uint32) + hist_params[0] = np.uint32(element_count) + hist_params[1] = np.uint32(self.value_count) + + segment_params_hist = np.zeros(3, dtype=np.uint32) + segment_params_hist[0] = np.uint32(self.block_count) + segment_params_hist[1] = np.uint32(self.block_span) + segment_params_hist[2] = np.uint32(self.value_count) + + stage1_params = np.zeros(3, dtype=np.uint32) + stage1_params[0] = np.uint32(self.block_group_count) + stage1_params[1] = np.uint32(self.blocks_per_group) + stage1_params[2] = np.uint32(self.block_count) + + stage2_params = np.zeros(1, dtype=np.uint32) + stage2_params[0] = np.uint32(self.block_group_count) + + scatter_params = np.zeros(5, dtype=np.uint32) + scatter_params[0] = np.uint32(element_count) + scatter_params[1] = np.uint32(self.value_count) + scatter_params[2] = np.uint32(self.block_span) + scatter_params[3] = np.uint32(self.blocks_per_group) + scatter_params[4] = np.uint32(1 if values_array is not None else 0) + + histogram_groups = _dispatch_groups(element_count, self.group_size) + block_sum_groups = _dispatch_groups(self.block_count, 64) + block_prefix_stage1_groups = _dispatch_groups(self.block_group_count, 64) + block_prefix_stage2_groups = _dispatch_groups(self.block_group_count, 64) + scatter_groups = _dispatch_groups(element_count, self.group_size) + + total_start = time.perf_counter() if collect_stage_timings else None + + start = time.perf_counter() if collect_stage_timings else None + self._histogram_program.run( + [key_in_buffer], + [histogram_buffer, local_rank_buffer], + histogram_groups, + hist_params, + ) + if collect_stage_timings and start is not None: + stage_totals["histogram"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._segment_scan_program.run( + [histogram_buffer], + [local_prefix_buffer, block_totals_buffer], + block_sum_groups, + segment_params_hist, + ) + if collect_stage_timings and start is not None: + stage_totals["block_sum"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._segment_scan_program.run( + [block_totals_buffer], + [block_prefix_stage1_buffer, block_group_totals_buffer], + block_prefix_stage1_groups, + stage1_params, + ) + if collect_stage_timings and start is not None: + stage_totals["block_prefix_stage1"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._block_prefix_stage2_program.run( + [block_group_totals_buffer], + [block_group_prefix_buffer], + block_prefix_stage2_groups, + stage2_params, + ) + if collect_stage_timings and start is not None: + stage_totals["block_prefix_stage2"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._scatter_program.run( + [ + key_in_buffer, + val_in_buffer, + local_rank_buffer, + local_prefix_buffer, + block_prefix_stage1_buffer, + block_group_prefix_buffer, + ], + [key_out_buffer, val_out_buffer], + scatter_groups, + scatter_params, + ) + if collect_stage_timings and start is not None: + stage_totals["scatter"] += time.perf_counter() - start + + if collect_stage_timings and total_start is not None: + stage_totals["total_pass"] = time.perf_counter() - total_start + + if validate: + validate_params = np.zeros(1, dtype=np.uint32) + validate_params[0] = np.uint32(element_count) + + error_buf = tf.createBuffer(1, 4, False) + error_zero = np.zeros(1, dtype=np.uint32) + error_buf.setData(error_zero) + + self._validate_program.run( + [key_out_buffer], + [error_buf], + _dispatch_groups(element_count, self.group_size), + validate_params, + ) + + error_count = int(error_buf.getData(np.dtype(np.uint32), 1)[0]) + self.last_validation_errors = error_count + else: + self.last_validation_errors = None + + if return_arrays: + sorted_keys = key_out_buffer.getData(np.dtype(np.uint32), element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_out_buffer.getData(values_dtype, element_count) + else: + sorted_values = None + else: + sorted_keys = np.empty(0, dtype=np.uint32) + sorted_values = None if values_array is None or values_dtype is None else np.empty(0, dtype=values_dtype) + + self.last_stage_timings = stage_totals if collect_stage_timings else None + return sorted_keys, sorted_values + + +_SORTER_CACHE: Dict[_SorterKey, CountingSort] = {} + + +def _get_sorter(max_value: int, block_span: int, group_size: int, blocks_per_group: int) -> CountingSort: + key = _SorterKey(max_value, block_span, group_size, blocks_per_group) + sorter = _SORTER_CACHE.get(key) + if sorter is None: + sorter = CountingSort( + max_value=max_value, + block_span=block_span, + group_size=group_size, + blocks_per_group=blocks_per_group, + ) + _SORTER_CACHE[key] = sorter + return sorter + + +def counting_sort( + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + max_value: int, + block_span: int = 256, + group_size: int = 128, + blocks_per_group: int = 256, +) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Convenience wrapper around :class:`CountingSort`.""" + + if max_value is None: + raise ValueError("max_value must be provided for counting_sort") + + sorter = _get_sorter(int(max_value), int(block_span), int(group_size), int(blocks_per_group)) + sorted_keys, sorted_values = sorter.sort(keys, values) + if values is None: + return sorted_keys + return sorted_keys, sorted_values diff --git a/examples/debug.py b/examples/debug.py index 8037c7bb..99213136 100644 --- a/examples/debug.py +++ b/examples/debug.py @@ -1,46 +1,57 @@ import numpy as np import TensorFrost as tf -tf.initialize(tf.cpu) - -blur_d = 16 -blur_r = blur_d * 0.5 - -def kernel(r): - #return 1.0 - return tf.exp(-0.5 * (r / blur_r)**2.0) / (blur_r * np.sqrt(2.0 * np.pi)) - -# def blurfunc(): -# img = tf.func_input([-1, -1], tf.float32) -# with tf.loop(-blur_d, blur_d+1, 1) as k: -# blur_h += img[i+k, j] * kernel(tf.float(k)) -# with tf.loop(-blur_d, blur_d+1, 1) as k: -# blur_v += blur_h[i, j+k] * kernel(tf.float(k)) -# return blur_v -# -# def blur(): -# img = tf.input([-1, -1, -1], tf.float32) -# N, M, C = img.shape -# blur_h = tf.zeros(img.shape, tf.float32) -# blur_v = tf.zeros(img.shape, tf.float32) -# i, j, ch = img.indices -# -# tf.vmap(inputs=[img], map=[C], func=blurfunc); -# -# return blur_v - -@tf.compile -def blur(img: tf.Arg([-1, -1, -1], tf.float32)): - blur_h = tf.zeros(img.shape, tf.float32) - blur_v = tf.zeros(img.shape, tf.float32) - i, j, ch = img.indices - - #horizontal blur - with tf.loop(-blur_d, blur_d+1, 1) as k: - blur_h += img[i+k, j, ch] * kernel(tf.float(k)) - - #vertical blur - with tf.loop(-blur_d, blur_d+1, 1) as k: - blur_v += blur_h[i, j+k, ch] * kernel(tf.float(k)) - - return blur_v +_SLANG = r""" +struct FillParams { + float4 color; +}; + +[[vk::push_constant]] +FillParams gParams; + +[[vk::binding(0,0)]] RWStructuredBuffer Pixels : register(u0, space0); + +[numthreads(64,1,1)] +void csMain(uint3 tid : SV_DispatchThreadID) +{ + uint idx = tid.x; + if (idx >= Pixels.length()) return; + + float4 c = saturate(gParams.color); + uint r = (uint)round(c.r * 255.0); + uint g = (uint)round(c.g * 255.0); + uint b = (uint)round(c.b * 255.0); + uint a = (uint)round(c.a * 255.0); + Pixels[idx] = r | (g << 8) | (b << 16) | (a << 24); +} +""" + + +def main() -> None: + width = height = 512 + local_size = 64 + thread_count = width * height + group_count = max((thread_count + local_size - 1) // local_size, 1) + + window = tf.createWindow(width, height, "TensorFrost Debug Fill") + pixel_buffer = tf.createBuffer(thread_count, 4, False) + program = tf.createComputeProgramFromSlang( + "debug_fill", + _SLANG, + "csMain", + ro_count=0, + rw_count=1, + push_constant_size=16, + ) + + color = np.array([0.15, 0.45, 0.95, 1.0], dtype=np.float32) + + while window.isOpen(): + program.run([], [pixel_buffer], group_count, color) + window.drawBuffer(pixel_buffer, width, height) + + window.close() + + +if __name__ == "__main__": + main() diff --git a/examples/imgui_showcase.py b/examples/imgui_showcase.py new file mode 100644 index 00000000..d3f52222 --- /dev/null +++ b/examples/imgui_showcase.py @@ -0,0 +1,270 @@ +"""Interactive ImGui showcase demonstrating the TensorFrost Vulkan window helpers. + +This example exercises a large portion of the Python ImGui bindings that are +exposed through ``TensorFrost.Window``. +""" + +from __future__ import annotations + +import math +import time +from collections import deque + +import numpy as np +import TensorFrost as tf + +# ImGui enums we need inside the sample. They mirror the values from imgui.h. +IMGUI_WINDOW_FLAGS_MENU_BAR = 1 << 3 +IMGUI_COL_WINDOW_BG = 2 +IMGUI_COL_BUTTON = 21 +IMGUI_COL_BUTTON_HOVERED = 22 +IMGUI_STYLEVAR_WINDOW_PADDING = 2 +IMGUI_STYLEVAR_FRAME_ROUNDING = 12 +IMGUI_STYLEVAR_ITEM_SPACING = 14 + + +def main() -> None: + width, height = 960, 600 + window = tf.createWindow(width, height, "TensorFrost ImGui Showcase") + + sample_history: deque[float] = deque(maxlen=512) + frame_times: deque[float] = deque(maxlen=240) + + start_time = time.perf_counter() + last_time = start_time + total_time = 0.0 + + state = { + "animate": True, + "show_plot": True, + "wave_speed": 1.0, + "wave_scale": 1.0, + "sample_count": 180, + "greeting": "Hello from TensorFrost!", + "accent": (0.25, 0.62, 0.98, 1.0), + "theme": "dark", + "font_scale": 2.0, + } + + themes = { + "dark": { + IMGUI_COL_WINDOW_BG: (0.1, 0.12, 0.16, 1.0), + IMGUI_COL_BUTTON: (0.27, 0.44, 0.85, 1.0), + IMGUI_COL_BUTTON_HOVERED: (0.36, 0.53, 0.92, 1.0), + }, + "light": { + IMGUI_COL_WINDOW_BG: (0.95, 0.96, 1.0, 1.0), + IMGUI_COL_BUTTON: (0.60, 0.78, 1.0, 1.0), + IMGUI_COL_BUTTON_HOVERED: (0.47, 0.66, 0.94, 1.0), + }, + "retro": { + IMGUI_COL_WINDOW_BG: (0.12, 0.10, 0.08, 1.0), + IMGUI_COL_BUTTON: (0.93, 0.74, 0.28, 1.0), + IMGUI_COL_BUTTON_HOVERED: (0.98, 0.84, 0.39, 1.0), + }, + } + + def apply_theme(name: str) -> None: + colors = themes[name] + for col_idx, value in colors.items(): + window.imgui_set_style_color_vec4(col_idx, value) + + apply_theme(state["theme"]) + window.imgui_scale_all_sizes(2.0) + window.imgui_set_font_global_scale(state["font_scale"]) + + while window.isOpen(): + now = time.perf_counter() + dt = now - last_time + last_time = now + total_time += dt + + frame_times.append(dt) + fps = len(frame_times) / sum(frame_times) if frame_times else 0.0 + + if state["animate"]: + sample = math.sin(total_time * state["wave_speed"]) * state["wave_scale"] + sample_history.append(sample) + else: + # Keep the history flat when paused so the plot stays visible. + if sample_history: + sample_history.append(sample_history[-1]) + else: + sample_history.append(0.0) + + # Keep only the user-requested number of samples from the history. + history_count = max(2, min(state["sample_count"], len(sample_history))) + history_array = np.array(list(sample_history)[-history_count:], dtype=np.float32) + + width, height = window.size + + if window.imgui_begin_main_menu_bar(): + if window.imgui_begin_menu("Theme"): + for name in themes: + if window.imgui_menu_item( + name.title(), shortcut=None, selected=state["theme"] == name, enabled=True + ): + state["theme"] = name + apply_theme(name) + window.imgui_end_menu() + + if window.imgui_begin_menu("View"): + if window.imgui_menu_item( + "Toggle animation", shortcut="A", selected=state["animate"], enabled=True + ): + state["animate"] = not state["animate"] + if window.imgui_menu_item( + "Show plot", shortcut="P", selected=state["show_plot"], enabled=True + ): + state["show_plot"] = not state["show_plot"] + if window.imgui_menu_item( + "Reset font scale", shortcut="Ctrl+0", selected=False, enabled=True + ): + state["font_scale"] = 2.0 + window.imgui_set_font_global_scale(state["font_scale"]) + window.imgui_end_menu() + + if window.imgui_begin_menu("Help"): + if window.imgui_menu_item("About TensorFrost", shortcut=None, selected=False, enabled=True): + window.imgui_open_popup("about_popup") + window.imgui_end_menu() + + window.imgui_end_main_menu_bar() + + about_visible, _ = window.imgui_begin_popup_modal("about_popup", open=None, flags=0) + if about_visible: + window.imgui_text_wrapped( + "TensorFrost ImGui showcase demonstrating the Python bindings for the Vulkan backend." + ) + window.imgui_spacing() + window.imgui_text("Bindings exercised:") + window.imgui_indent(12.0) + window.imgui_bullet_text("Main menu bar helpers") + window.imgui_bullet_text("Layout, widgets, and style stack APIs") + window.imgui_bullet_text("Background draw list utilities") + window.imgui_unindent(12.0) + window.imgui_spacing() + if window.imgui_button("Close##about"): + window.imgui_close_current_popup() + window.imgui_end_popup() + + visible, _ = window.imgui_begin( + "Control Center", open=None, flags=IMGUI_WINDOW_FLAGS_MENU_BAR + ) + if visible: + if window.imgui_begin_menu_bar(): + if window.imgui_begin_menu("View"): + if window.imgui_menu_item("Reset font scale", shortcut="Ctrl+0", selected=False, enabled=True): + state["font_scale"] = 2.0 + window.imgui_set_font_global_scale(state["font_scale"]) + if window.imgui_menu_item("Show plot", shortcut="P", selected=state["show_plot"], enabled=True): + state["show_plot"] = not state["show_plot"] + window.imgui_end_menu() + + window.imgui_end_menu_bar() + + window.imgui_text(f"Frame time: {dt * 1000.0:5.2f} ms") + window.imgui_text_colored((0.25, 0.85, 0.45, 1.0), f"Average FPS: {fps:5.1f}") + window.imgui_separator() + + controls_visible = window.imgui_begin_child("controls", (320, 280), border=True, flags=0) + if controls_visible: + edited, new_text = window.imgui_input_text("Greeting", state["greeting"], buffer_length=128) + if edited: + state["greeting"] = new_text + + state["animate"] = window.imgui_checkbox("Animate", state["animate"]) + window.imgui_same_line() + state["show_plot"] = window.imgui_checkbox("Show plot", state["show_plot"]) + + state["wave_speed"] = window.imgui_slider_float("Wave speed", state["wave_speed"], 0.1, 5.0) + state["wave_scale"] = window.imgui_slider_float("Wave scale", state["wave_scale"], 0.1, 3.0) + state["sample_count"] = window.imgui_slider_int("History samples", state["sample_count"], 20, sample_history.maxlen) + + window.imgui_spacing() + state["font_scale"] = window.imgui_slider_float("Font scale", state["font_scale"], 0.5, 2.0) + window.imgui_set_font_global_scale(state["font_scale"]) + + window.imgui_spacing() + changed, new_color = window.imgui_color_edit4("Accent color", state["accent"]) + if changed: + state["accent"] = tuple(new_color) + + window.imgui_spacing() + window.imgui_push_style_var_float(IMGUI_STYLEVAR_FRAME_ROUNDING, 8.0) + window.imgui_push_style_color(IMGUI_COL_BUTTON, state["accent"]) + window.imgui_push_style_color(IMGUI_COL_BUTTON_HOVERED, ( + min(1.0, state["accent"][0] + 0.1), + min(1.0, state["accent"][1] + 0.1), + min(1.0, state["accent"][2] + 0.1), + state["accent"][3], + )) + if window.imgui_button("Take snapshot"): + window.imgui_open_popup("snapshot_popup") + window.imgui_pop_style_color(2) + window.imgui_pop_style_var() + + visible_popup, _ = window.imgui_begin_popup_modal("snapshot_popup", open=None, flags=0) + if visible_popup: + window.imgui_text_wrapped( + "Pretend we stored the latest plot sample to disk." + ) + window.imgui_spacing() + if window.imgui_button("Close"): + window.imgui_close_current_popup() + window.imgui_end_popup() + window.imgui_end_child() + + window.imgui_same_line() + + details_visible = window.imgui_begin_child("details", (0, 280), border=True, flags=0) + if details_visible: + window.imgui_text("Details") + window.imgui_separator() + window.imgui_text_wrapped( + "The live sine wave demonstrates sliders, checkboxes, plot widgets, " + "menus, popups, style stacks, and color editing exposed through the Vulkan backend." + ) + + window.imgui_spacing() + window.imgui_indent(10.0) + window.imgui_bullet_text("Toggle themes from the menu bar.") + window.imgui_bullet_text("Drag the sliders to shape the waveform.") + window.imgui_bullet_text("Use the accent color to restyle the snapshot button.") + window.imgui_unindent(10.0) + + window.imgui_spacing() + window.imgui_text_colored(state["accent"], f"Greeting: {state['greeting']}") + + window.imgui_spacing() + window.imgui_push_style_var_vec2(IMGUI_STYLEVAR_WINDOW_PADDING, (12.0, 12.0)) + window.imgui_text("Background text is rendered via the overlay draw list.") + window.imgui_pop_style_var() + window.imgui_end_child() + + window.imgui_separator() + + if state["show_plot"] and history_array.size > 1: + window.imgui_plot_lines( + "Sine wave", history_array, values_offset=0, overlay_text="Normalized", scale_min=-3.0, + scale_max=3.0, graph_size=(0.0, 140.0), stride=4 + ) + + window.imgui_spacing() + window.imgui_text(f"Window size: {width} x {height}") + + window.imgui_end() + + window.imgui_add_background_text( + state["greeting"], + pos=(24.0, 24.0), + color=(state["accent"][0], state["accent"][1], state["accent"][2], 0.12), + ) + + window.present() + + window.close() + + +if __name__ == "__main__": + main() diff --git a/examples/Algorithms/bitonic.ipynb b/examples/legacy/Algorithms/bitonic.ipynb similarity index 100% rename from examples/Algorithms/bitonic.ipynb rename to examples/legacy/Algorithms/bitonic.ipynb diff --git a/examples/Algorithms/custom_operation.py b/examples/legacy/Algorithms/custom_operation.py similarity index 100% rename from examples/Algorithms/custom_operation.py rename to examples/legacy/Algorithms/custom_operation.py diff --git a/examples/Algorithms/fft.ipynb b/examples/legacy/Algorithms/fft.ipynb similarity index 100% rename from examples/Algorithms/fft.ipynb rename to examples/legacy/Algorithms/fft.ipynb diff --git a/examples/Algorithms/fft_group.ipynb b/examples/legacy/Algorithms/fft_group.ipynb similarity index 100% rename from examples/Algorithms/fft_group.ipynb rename to examples/legacy/Algorithms/fft_group.ipynb diff --git a/examples/Algorithms/indexing.py b/examples/legacy/Algorithms/indexing.py similarity index 100% rename from examples/Algorithms/indexing.py rename to examples/legacy/Algorithms/indexing.py diff --git a/examples/Algorithms/indexing_test.ipynb b/examples/legacy/Algorithms/indexing_test.ipynb similarity index 100% rename from examples/Algorithms/indexing_test.ipynb rename to examples/legacy/Algorithms/indexing_test.ipynb diff --git a/examples/Algorithms/kernels.ipynb b/examples/legacy/Algorithms/kernels.ipynb similarity index 100% rename from examples/Algorithms/kernels.ipynb rename to examples/legacy/Algorithms/kernels.ipynb diff --git a/examples/Algorithms/matrix_mul.ipynb b/examples/legacy/Algorithms/matrix_mul.ipynb similarity index 100% rename from examples/Algorithms/matrix_mul.ipynb rename to examples/legacy/Algorithms/matrix_mul.ipynb diff --git a/examples/Algorithms/qr.ipynb b/examples/legacy/Algorithms/qr.ipynb similarity index 100% rename from examples/Algorithms/qr.ipynb rename to examples/legacy/Algorithms/qr.ipynb diff --git a/examples/Algorithms/random_number_generation.ipynb b/examples/legacy/Algorithms/random_number_generation.ipynb similarity index 100% rename from examples/Algorithms/random_number_generation.ipynb rename to examples/legacy/Algorithms/random_number_generation.ipynb diff --git a/examples/Algorithms/random_permutation.ipynb b/examples/legacy/Algorithms/random_permutation.ipynb similarity index 100% rename from examples/Algorithms/random_permutation.ipynb rename to examples/legacy/Algorithms/random_permutation.ipynb diff --git a/examples/Algorithms/reshape_reduction.ipynb b/examples/legacy/Algorithms/reshape_reduction.ipynb similarity index 100% rename from examples/Algorithms/reshape_reduction.ipynb rename to examples/legacy/Algorithms/reshape_reduction.ipynb diff --git a/examples/Algorithms/scan.ipynb b/examples/legacy/Algorithms/scan.ipynb similarity index 100% rename from examples/Algorithms/scan.ipynb rename to examples/legacy/Algorithms/scan.ipynb diff --git a/examples/Algorithms/scatter.py b/examples/legacy/Algorithms/scatter.py similarity index 100% rename from examples/Algorithms/scatter.py rename to examples/legacy/Algorithms/scatter.py diff --git a/examples/Algorithms/sorting.ipynb b/examples/legacy/Algorithms/sorting.ipynb similarity index 100% rename from examples/Algorithms/sorting.ipynb rename to examples/legacy/Algorithms/sorting.ipynb diff --git a/examples/Demos/buddhabrot.gif b/examples/legacy/Demos/buddhabrot.gif similarity index 100% rename from examples/Demos/buddhabrot.gif rename to examples/legacy/Demos/buddhabrot.gif diff --git a/examples/Demos/fluid_sim.gif b/examples/legacy/Demos/fluid_sim.gif similarity index 100% rename from examples/Demos/fluid_sim.gif rename to examples/legacy/Demos/fluid_sim.gif diff --git a/examples/Demos/n_body.gif b/examples/legacy/Demos/n_body.gif similarity index 100% rename from examples/Demos/n_body.gif rename to examples/legacy/Demos/n_body.gif diff --git a/examples/Demos/nca.gif b/examples/legacy/Demos/nca.gif similarity index 100% rename from examples/Demos/nca.gif rename to examples/legacy/Demos/nca.gif diff --git a/examples/Demos/neural_embed.gif b/examples/legacy/Demos/neural_embed.gif similarity index 100% rename from examples/Demos/neural_embed.gif rename to examples/legacy/Demos/neural_embed.gif diff --git a/examples/Demos/path_tracer.gif b/examples/legacy/Demos/path_tracer.gif similarity index 100% rename from examples/Demos/path_tracer.gif rename to examples/legacy/Demos/path_tracer.gif diff --git a/examples/Demos/sin_gordon.gif b/examples/legacy/Demos/sin_gordon.gif similarity index 100% rename from examples/Demos/sin_gordon.gif rename to examples/legacy/Demos/sin_gordon.gif diff --git a/examples/GUI/buddhabrot.py b/examples/legacy/GUI/buddhabrot.py similarity index 100% rename from examples/GUI/buddhabrot.py rename to examples/legacy/GUI/buddhabrot.py diff --git a/examples/GUI/garden_smol.hdr b/examples/legacy/GUI/garden_smol.hdr similarity index 100% rename from examples/GUI/garden_smol.hdr rename to examples/legacy/GUI/garden_smol.hdr diff --git a/examples/GUI/image_matcher.py b/examples/legacy/GUI/image_matcher.py similarity index 100% rename from examples/GUI/image_matcher.py rename to examples/legacy/GUI/image_matcher.py diff --git a/examples/GUI/interactive_path_tracer.py b/examples/legacy/GUI/interactive_path_tracer.py similarity index 100% rename from examples/GUI/interactive_path_tracer.py rename to examples/legacy/GUI/interactive_path_tracer.py diff --git a/examples/ML/MNIST/MNIST.ipynb b/examples/legacy/ML/MNIST/MNIST.ipynb similarity index 100% rename from examples/ML/MNIST/MNIST.ipynb rename to examples/legacy/ML/MNIST/MNIST.ipynb diff --git a/examples/ML/MNIST/loadMNIST.py b/examples/legacy/ML/MNIST/loadMNIST.py similarity index 100% rename from examples/ML/MNIST/loadMNIST.py rename to examples/legacy/ML/MNIST/loadMNIST.py diff --git a/examples/ML/MNIST/module.py b/examples/legacy/ML/MNIST/module.py similarity index 100% rename from examples/ML/MNIST/module.py rename to examples/legacy/ML/MNIST/module.py diff --git a/examples/ML/MNIST/pytorch.py b/examples/legacy/ML/MNIST/pytorch.py similarity index 100% rename from examples/ML/MNIST/pytorch.py rename to examples/legacy/ML/MNIST/pytorch.py diff --git a/examples/ML/NCA/bugcat.png b/examples/legacy/ML/NCA/bugcat.png similarity index 100% rename from examples/ML/NCA/bugcat.png rename to examples/legacy/ML/NCA/bugcat.png diff --git a/examples/ML/NCA/catthink.png b/examples/legacy/ML/NCA/catthink.png similarity index 100% rename from examples/ML/NCA/catthink.png rename to examples/legacy/ML/NCA/catthink.png diff --git a/examples/ML/NCA/inference.py b/examples/legacy/ML/NCA/inference.py similarity index 100% rename from examples/ML/NCA/inference.py rename to examples/legacy/ML/NCA/inference.py diff --git a/examples/ML/NCA/nca.py b/examples/legacy/ML/NCA/nca.py similarity index 100% rename from examples/ML/NCA/nca.py rename to examples/legacy/ML/NCA/nca.py diff --git a/examples/ML/NCA/shadertoy.py b/examples/legacy/ML/NCA/shadertoy.py similarity index 100% rename from examples/ML/NCA/shadertoy.py rename to examples/legacy/ML/NCA/shadertoy.py diff --git a/examples/ML/NCA/train.py b/examples/legacy/ML/NCA/train.py similarity index 100% rename from examples/ML/NCA/train.py rename to examples/legacy/ML/NCA/train.py diff --git a/examples/ML/VMC/atom.py b/examples/legacy/ML/VMC/atom.py similarity index 100% rename from examples/ML/VMC/atom.py rename to examples/legacy/ML/VMC/atom.py diff --git a/examples/ML/VMC/camera.py b/examples/legacy/ML/VMC/camera.py similarity index 100% rename from examples/ML/VMC/camera.py rename to examples/legacy/ML/VMC/camera.py diff --git a/examples/ML/VMC/logdet.py b/examples/legacy/ML/VMC/logdet.py similarity index 100% rename from examples/ML/VMC/logdet.py rename to examples/legacy/ML/VMC/logdet.py diff --git a/examples/ML/VMC/logdet_test.py b/examples/legacy/ML/VMC/logdet_test.py similarity index 100% rename from examples/ML/VMC/logdet_test.py rename to examples/legacy/ML/VMC/logdet_test.py diff --git a/examples/ML/VMC/molecules.py b/examples/legacy/ML/VMC/molecules.py similarity index 100% rename from examples/ML/VMC/molecules.py rename to examples/legacy/ML/VMC/molecules.py diff --git a/examples/ML/VMC/utils.py b/examples/legacy/ML/VMC/utils.py similarity index 100% rename from examples/ML/VMC/utils.py rename to examples/legacy/ML/VMC/utils.py diff --git a/examples/ML/VMC/vec3.py b/examples/legacy/ML/VMC/vec3.py similarity index 100% rename from examples/ML/VMC/vec3.py rename to examples/legacy/ML/VMC/vec3.py diff --git a/examples/ML/VMC/visualizer.py b/examples/legacy/ML/VMC/visualizer.py similarity index 100% rename from examples/ML/VMC/visualizer.py rename to examples/legacy/ML/VMC/visualizer.py diff --git a/examples/ML/VMC/visualizer_test.py b/examples/legacy/ML/VMC/visualizer_test.py similarity index 100% rename from examples/ML/VMC/visualizer_test.py rename to examples/legacy/ML/VMC/visualizer_test.py diff --git a/examples/ML/VMC/vmc.py b/examples/legacy/ML/VMC/vmc.py similarity index 100% rename from examples/ML/VMC/vmc.py rename to examples/legacy/ML/VMC/vmc.py diff --git a/examples/Rendering/blur.ipynb b/examples/legacy/Rendering/blur.ipynb similarity index 100% rename from examples/Rendering/blur.ipynb rename to examples/legacy/Rendering/blur.ipynb diff --git a/examples/Rendering/buddhabrot.ipynb b/examples/legacy/Rendering/buddhabrot.ipynb similarity index 100% rename from examples/Rendering/buddhabrot.ipynb rename to examples/legacy/Rendering/buddhabrot.ipynb diff --git a/examples/Rendering/convolution.py b/examples/legacy/Rendering/convolution.py similarity index 100% rename from examples/Rendering/convolution.py rename to examples/legacy/Rendering/convolution.py diff --git a/examples/Rendering/fft2d.ipynb b/examples/legacy/Rendering/fft2d.ipynb similarity index 100% rename from examples/Rendering/fft2d.ipynb rename to examples/legacy/Rendering/fft2d.ipynb diff --git a/examples/Rendering/fft3d.py b/examples/legacy/Rendering/fft3d.py similarity index 100% rename from examples/Rendering/fft3d.py rename to examples/legacy/Rendering/fft3d.py diff --git a/examples/Rendering/gaussian_grid.ipynb b/examples/legacy/Rendering/gaussian_grid.ipynb similarity index 100% rename from examples/Rendering/gaussian_grid.ipynb rename to examples/legacy/Rendering/gaussian_grid.ipynb diff --git a/examples/Rendering/mandelbrot.ipynb b/examples/legacy/Rendering/mandelbrot.ipynb similarity index 100% rename from examples/Rendering/mandelbrot.ipynb rename to examples/legacy/Rendering/mandelbrot.ipynb diff --git a/examples/Rendering/neural_embed.ipynb b/examples/legacy/Rendering/neural_embed.ipynb similarity index 100% rename from examples/Rendering/neural_embed.ipynb rename to examples/legacy/Rendering/neural_embed.ipynb diff --git a/examples/Rendering/neural_embed2.ipynb b/examples/legacy/Rendering/neural_embed2.ipynb similarity index 100% rename from examples/Rendering/neural_embed2.ipynb rename to examples/legacy/Rendering/neural_embed2.ipynb diff --git a/examples/Rendering/ray_marcher.ipynb b/examples/legacy/Rendering/ray_marcher.ipynb similarity index 100% rename from examples/Rendering/ray_marcher.ipynb rename to examples/legacy/Rendering/ray_marcher.ipynb diff --git a/examples/Rendering/sphere_tracer.ipynb b/examples/legacy/Rendering/sphere_tracer.ipynb similarity index 100% rename from examples/Rendering/sphere_tracer.ipynb rename to examples/legacy/Rendering/sphere_tracer.ipynb diff --git a/examples/Rendering/test.png b/examples/legacy/Rendering/test.png similarity index 100% rename from examples/Rendering/test.png rename to examples/legacy/Rendering/test.png diff --git a/examples/Simulation/fluid_simulation.ipynb b/examples/legacy/Simulation/fluid_simulation.ipynb similarity index 100% rename from examples/Simulation/fluid_simulation.ipynb rename to examples/legacy/Simulation/fluid_simulation.ipynb diff --git a/examples/Simulation/n-body-benchmark.py b/examples/legacy/Simulation/n-body-benchmark.py similarity index 100% rename from examples/Simulation/n-body-benchmark.py rename to examples/legacy/Simulation/n-body-benchmark.py diff --git a/examples/Simulation/n-body.ipynb b/examples/legacy/Simulation/n-body.ipynb similarity index 100% rename from examples/Simulation/n-body.ipynb rename to examples/legacy/Simulation/n-body.ipynb diff --git a/examples/Simulation/poission.py b/examples/legacy/Simulation/poission.py similarity index 100% rename from examples/Simulation/poission.py rename to examples/legacy/Simulation/poission.py diff --git a/examples/Simulation/wave_simulation.ipynb b/examples/legacy/Simulation/wave_simulation.ipynb similarity index 100% rename from examples/Simulation/wave_simulation.ipynb rename to examples/legacy/Simulation/wave_simulation.ipynb diff --git a/examples/radix_sort/__main__.py b/examples/radix_sort/__main__.py new file mode 100644 index 00000000..60c321ec --- /dev/null +++ b/examples/radix_sort/__main__.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import argparse +import math +import time +from collections import deque + +import numpy as np +import TensorFrost as tf + +try: + from .sort import HistogramRadixSort +except ImportError: + import sys + from pathlib import Path + + _CURRENT_DIR = Path(__file__).resolve().parent + if str(_CURRENT_DIR) not in sys.path: + sys.path.insert(0, str(_CURRENT_DIR)) + + from sort import HistogramRadixSort + + +_STAGE_NAMES = ( + "map_to_uint", + "histogram", + "unpack", + "prefix_local", + "prefix_blocks", + "prefix_accum", + "bucket_scan", + "scatter", + "map_from_uint", +) + + +def _select_backend() -> None: + if hasattr(tf, "initialize"): + backend = getattr(tf, "vulkan", None) + if backend is None: + raise RuntimeError("TensorFrost Vulkan backend is unavailable on this build") + tf.initialize(backend) + + +def _format_rate(value: float, unit: str) -> str: + if value <= 0.0 or not math.isfinite(value): + return f"0 {unit}" + + prefixes = ("", "K", "M", "G", "T", "P") + magnitude = 0 + while value >= 1000.0 and magnitude < len(prefixes) - 1: + value /= 1000.0 + magnitude += 1 + + return f"{value:6.2f} {prefixes[magnitude]}{unit}" + + +def _format_stage_name(name: str) -> str: + return name.replace("_", " ").title() + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Interactive histogram radix sort demo running continuously on the Vulkan backend.", + ) + parser.add_argument("--size", type=int, default=1 << 20, help="Number of key/value pairs to sort") + parser.add_argument("--bits", type=int, default=6, help="Bits processed per pass") + parser.add_argument("--window-width", type=int, default=960, help="Window width in pixels") + parser.add_argument("--window-height", type=int, default=540, help="Window height in pixels") + parser.add_argument("--font-scale", type=float, default=1.6, help="Global ImGui font scale") + parser.add_argument("--history", type=int, default=240, help="Number of samples stored for metrics history") + parser.add_argument("--seed", type=int, default=1337, help="Seed for the random data generator") + args = parser.parse_args() + + _select_backend() + + count = max(0, int(args.size)) + bits_per_pass = max(1, int(args.bits)) + window_width = max(320, int(args.window_width)) + window_height = max(240, int(args.window_height)) + history_length = max(1, int(args.history)) + font_scale = max(0.1, float(args.font_scale)) if args.font_scale > 0 else 0.0 + + rng = np.random.default_rng(int(args.seed)) + keys = ( + rng.standard_normal(count, dtype=np.float32) + if count + else np.empty(0, dtype=np.float32) + ) + values = ( + rng.integers(0, 1 << 31, size=count, dtype=np.uint32) + if count + else np.empty(0, dtype=np.uint32) + ) + + # GPU-side validation: we'll run a one-time kernel check that counts out-of-order adjacent pairs. + # This avoids reading back entire arrays for CPU comparison. + reference_keys = None + reference_values = None + + frame_times = deque(maxlen=history_length) + sort_times = deque(maxlen=history_length) + stage_totals_overall = {name: 0.0 for name in _STAGE_NAMES} + + validated = False + validation_ok = False + validation_message = "" + + sort_count = 0 + total_kernel_time = 0.0 + + sorter = HistogramRadixSort(bits_per_pass=bits_per_pass) + + window_title = f"TensorFrost Radix Sort ({count:,} elements)" + window = tf.createWindow(window_width, window_height, window_title) + + if font_scale > 0.0 and window is not None: + window.imgui_scale_all_sizes(font_scale) + window.imgui_set_font_global_scale(font_scale) + + start_time = time.perf_counter() + last_frame_time = start_time + + while window is not None and window.isOpen(): + now = time.perf_counter() + dt = now - last_frame_time + last_frame_time = now + + if dt > 0.0: + frame_times.append(dt) + frame_time_total = sum(frame_times) + fps = (len(frame_times) / frame_time_total) if frame_times and frame_time_total > 0.0 else 0.0 + + # Perform sort; on the first run also validate on GPU and read back buffers for NumPy comparison. + do_validate = not validated + return_arrays = do_validate + _keys_out, _vals_out = sorter.sort( + keys, + values, + collect_stage_timings=True, + validate=do_validate, + return_arrays=return_arrays, + ) + stage_timings = sorter.last_stage_timings or {} + # Separate total_pass (overall) from per-stage summed time + total_pass_time = float(stage_timings.get("total_pass", 0.0)) + stage_sum_time = float(sum(v for k, v in stage_timings.items() if k != "total_pass")) + kernel_time = total_pass_time if total_pass_time > 0.0 else stage_sum_time + + if sort_times.maxlen == len(sort_times): + sort_times.popleft() + sort_times.append(kernel_time) + + sort_count += 1 + total_kernel_time += kernel_time + for name in _STAGE_NAMES: + stage_totals_overall[name] += stage_timings.get(name, 0.0) + + if do_validate: + errors = int(getattr(sorter, "last_validation_errors", 0) or 0) + gpu_validation_ok = (errors == 0) + + cpu_validation_ok = True + cpu_messages = [] + if return_arrays: + if count: + reference_indices = np.argsort(keys, kind="stable") + reference_keys = keys[reference_indices] + reference_values = values[reference_indices] if values is not None else None + else: + reference_keys = np.copy(keys) + reference_values = np.copy(values) if values is not None else None + + if not np.array_equal(_keys_out, reference_keys): + cpu_validation_ok = False + cpu_messages.append("key mismatch") + + if values is not None and _vals_out is not None and not np.array_equal(_vals_out, reference_values): + cpu_validation_ok = False + cpu_messages.append("value mismatch") + + validation_ok = gpu_validation_ok and cpu_validation_ok + if validation_ok: + validation_message = "GPU + NumPy validation passed." + else: + failure_reasons = [] + if not gpu_validation_ok: + failure_reasons.append(f"GPU reported {errors} out-of-order pairs") + if not cpu_validation_ok: + reason = ", ".join(cpu_messages) if cpu_messages else "NumPy comparison failed" + failure_reasons.append(reason) + validation_message = "Validation failed: " + "; ".join(failure_reasons) + + validated = True + + # No large array readback performed when return_arrays=False. + + window_avg_sort = (sum(sort_times) / len(sort_times)) if sort_times else 0.0 + last_sort_ms = kernel_time * 1000.0 + avg_sort_ms = window_avg_sort * 1000.0 + + last_sort_rate = (1.0 / kernel_time) if kernel_time > 0.0 else 0.0 + window_sort_rate = (1.0 / window_avg_sort) if window_avg_sort > 0.0 else 0.0 + overall_avg_sort = (total_kernel_time / sort_count) if sort_count else 0.0 + overall_sort_rate = (1.0 / overall_avg_sort) if overall_avg_sort > 0.0 else 0.0 + + last_elements_per_sec = (count / kernel_time) if count and kernel_time > 0.0 else 0.0 + window_elements_per_sec = (count / window_avg_sort) if count and window_avg_sort > 0.0 else 0.0 + + history_data = np.array(sort_times, dtype=np.float32) * 1000.0 + total_elapsed = now - start_time + + visible, _ = window.imgui_begin("Radix Sort Performance", open=None, flags=0) + if visible: + window.imgui_text(f"Elements: {count:,}") + window.imgui_text(f"Bits per pass: {bits_per_pass}") + window.imgui_separator() + window.imgui_text(f"Frame time: {dt * 1000.0:7.3f} ms") + window.imgui_text(f"Average FPS: {fps:6.1f}") + window.imgui_separator() + + if sort_count: + window.imgui_text(f"Last kernel time: {last_sort_ms:7.3f} ms") + if total_pass_time > 0.0: + window.imgui_text(f"Total pass time (last): {total_pass_time * 1000.0:7.3f} ms") + if window_avg_sort > 0.0: + window.imgui_text(f"Window avg kernel: {avg_sort_ms:7.3f} ms") + if overall_avg_sort > 0.0: + window.imgui_text(f"Overall avg kernel: {overall_avg_sort * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text(f"Sorts/sec (last): {last_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (window): {window_sort_rate:8.2f}") + window.imgui_text(f"Sorts/sec (overall): {overall_sort_rate:8.2f}") + + if count: + window.imgui_spacing() + window.imgui_text( + f"Elements/sec (last): {_format_rate(last_elements_per_sec, ' elements/s')}" + ) + window.imgui_text( + f"Elements/sec (window): {_format_rate(window_elements_per_sec, ' elements/s')}" + ) + + window.imgui_spacing() + window.imgui_text(f"Total sorts: {sort_count:,}") + else: + if count: + window.imgui_text("Waiting for first GPU sort...") + else: + window.imgui_text("No elements to sort.") + + if validated: + color = (0.20, 0.80, 0.40, 1.0) if validation_ok else (0.92, 0.35, 0.32, 1.0) + window.imgui_text_colored(color, validation_message) + elif count: + window.imgui_text("Validating against CPU reference...") + + if history_data.size: + max_plot = float(np.max(history_data)) if history_data.size else 1.0 + max_plot = max(1.0, max_plot * 1.1) + window.imgui_plot_lines( + "Kernel time history (ms)", + history_data, + values_offset=0, + overlay_text=f"{history_data.size} sample history", + scale_min=0.0, + scale_max=max_plot, + graph_size=(0.0, 140.0), + stride=4, + ) + + if stage_timings: + window.imgui_separator() + window.imgui_text("Kernel stages (last run):") + for name, duration in sorted(stage_timings.items(), key=lambda item: item[1], reverse=True): + window.imgui_text(f"{_format_stage_name(name)}: {duration * 1000.0:7.3f} ms") + + window.imgui_spacing() + window.imgui_text("Kernel stages (overall avg):") + for name in sorted(stage_totals_overall.keys()): + avg_duration = (stage_totals_overall[name] / sort_count) if sort_count else 0.0 + window.imgui_text(f"{_format_stage_name(name)}: {avg_duration * 1000.0:7.3f} ms") + + window.imgui_end() + window.present() + + +if __name__ == "__main__": + main() diff --git a/examples/radix_sort/shaders/__init__.py b/examples/radix_sort/shaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/radix_sort/shaders/bucket_scan.slang b/examples/radix_sort/shaders/bucket_scan.slang new file mode 100644 index 00000000..cf57aa03 --- /dev/null +++ b/examples/radix_sort/shaders/bucket_scan.slang @@ -0,0 +1,47 @@ +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer GroupPrefix : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer BucketScan : register(u1, space0); + +[numthreads(64, 1, 1)] +void csBucketScan(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + if (dispatchThreadID.x != 0) + return; + + uint histogramSize = gParams.histogramSize; + uint numGroups = gParams.numGroups; + + if (histogramSize == 0) + return; + + if (numGroups == 0) + { + for (uint bit = 0; bit < histogramSize; ++bit) + { + BucketScan[bit] = 0; + } + return; + } + + uint base = (numGroups - 1u) * histogramSize; + uint sum = 0; + for (uint bit = 0; bit < histogramSize; ++bit) + { + uint total = GroupPrefix[base + bit]; + sum += total; + BucketScan[bit] = sum; + } +} diff --git a/examples/radix_sort/shaders/histogram.slang b/examples/radix_sort/shaders/histogram.slang new file mode 100644 index 00000000..393b9c1d --- /dev/null +++ b/examples/radix_sort/shaders/histogram.slang @@ -0,0 +1,72 @@ +static const uint GROUP_SIZE = TF_GROUP_SIZE; +static const uint QUARTER_SIZE = GROUP_SIZE / 4; +static const uint MAX_HIST_SIZE = 256u; + +uint packedCount(uint histogramSize) +{ + return (histogramSize + 3u) >> 2; +} + +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer KeysIn : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer PackedHistogram : register(u1, space0); + +groupshared uint sPacked[MAX_HIST_SIZE / 4]; + +[numthreads(GROUP_SIZE, 1, 1)] +void csHistogram(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) +{ + uint elementCount = gParams.elementCount; + uint histogramSize = gParams.histogramSize; + uint shift = gParams.shift; + uint mask = gParams.mask; + uint numGroups = gParams.numGroups; + + if (histogramSize > MAX_HIST_SIZE) + return; + + uint group = groupID.x; + if (group >= numGroups) + return; + + uint packedCountLocal = packedCount(histogramSize); + uint lane = localID.x; + + for (uint idx = lane; idx < packedCountLocal; idx += GROUP_SIZE) + { + sPacked[idx] = 0; + } + GroupMemoryBarrierWithGroupSync(); + + uint globalIndex = group * GROUP_SIZE + lane; + if (globalIndex < elementCount) + { + uint key = KeysIn[globalIndex]; + uint bit = (key >> shift) & mask; + uint slot = bit >> 2; + if (slot < packedCountLocal) + { + uint laneShift = (bit & 3u) * 8u; + InterlockedAdd(sPacked[slot], 1u << laneShift); + } + } + GroupMemoryBarrierWithGroupSync(); + + for (uint idx = lane; idx < packedCountLocal; idx += GROUP_SIZE) + { + PackedHistogram[group * packedCountLocal + idx] = sPacked[idx]; + } +} diff --git a/examples/radix_sort/shaders/map_from_uint.slang b/examples/radix_sort/shaders/map_from_uint.slang new file mode 100644 index 00000000..93491154 --- /dev/null +++ b/examples/radix_sort/shaders/map_from_uint.slang @@ -0,0 +1,41 @@ +static const uint TYPE_UINT = 0u; +static const uint TYPE_INT = 1u; +static const uint TYPE_FLOAT = 2u; +static const uint SIGN_BIT = 0x80000000u; +static const uint FULL_MASK = 0xFFFFFFFFu; +static const uint GROUP_SIZE = TF_GROUP_SIZE; + +struct MapParams { + uint count; + uint typeCode; +}; + +[[vk::push_constant]] +MapParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer Input : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Output : register(u1, space0); + +[numthreads(GROUP_SIZE, 1, 1)] +void csMapFromUint(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + uint count = gParams.count; + if (index >= count) + return; + + uint typeCode = gParams.typeCode; + uint value = Input[index]; + + if (typeCode == TYPE_INT) + { + value ^= SIGN_BIT; + } + else if (typeCode == TYPE_FLOAT) + { + uint mask = ((value >> 31) == 0u) ? FULL_MASK : SIGN_BIT; + value ^= mask; + } + + Output[index] = value; +} diff --git a/examples/radix_sort/shaders/map_to_uint.slang b/examples/radix_sort/shaders/map_to_uint.slang new file mode 100644 index 00000000..9cf4863a --- /dev/null +++ b/examples/radix_sort/shaders/map_to_uint.slang @@ -0,0 +1,41 @@ +static const uint TYPE_UINT = 0u; +static const uint TYPE_INT = 1u; +static const uint TYPE_FLOAT = 2u; +static const uint SIGN_BIT = 0x80000000u; +static const uint FULL_MASK = 0xFFFFFFFFu; +static const uint GROUP_SIZE = TF_GROUP_SIZE; + +struct MapParams { + uint count; + uint typeCode; +}; + +[[vk::push_constant]] +MapParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer Input : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer Output : register(u1, space0); + +[numthreads(GROUP_SIZE, 1, 1)] +void csMapToUint(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + uint count = gParams.count; + if (index >= count) + return; + + uint typeCode = gParams.typeCode; + uint value = Input[index]; + + if (typeCode == TYPE_INT) + { + value ^= SIGN_BIT; + } + else if (typeCode == TYPE_FLOAT) + { + uint mask = ((value >> 31) == 1u) ? FULL_MASK : SIGN_BIT; + value ^= mask; + } + + Output[index] = value; +} diff --git a/examples/radix_sort/shaders/prefix_accum.slang b/examples/radix_sort/shaders/prefix_accum.slang new file mode 100644 index 00000000..5cd2ea3f --- /dev/null +++ b/examples/radix_sort/shaders/prefix_accum.slang @@ -0,0 +1,45 @@ +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer BlockPrefix : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer GroupPrefix : register(u1, space0); + +[numthreads(64, 1, 1)] +void csPrefixAccumulate(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint histogramSize = gParams.histogramSize; + uint numGroups = gParams.numGroups; + uint blockSize = gParams.blockSize; + uint blockCount = gParams.blockCount; + + uint totalThreads = blockCount * histogramSize; + uint index = dispatchThreadID.x; + if (index >= totalThreads) + return; + + uint block = index / histogramSize; + uint bit = index % histogramSize; + + uint startGroup = block * blockSize; + if (startGroup >= numGroups) + return; + + uint endGroup = min(startGroup + blockSize, numGroups); + uint prefix = BlockPrefix[index]; + for (uint g = startGroup; g < endGroup; ++g) + { + uint idx = g * histogramSize + bit; + GroupPrefix[idx] += prefix; + } +} diff --git a/examples/radix_sort/shaders/prefix_block.slang b/examples/radix_sort/shaders/prefix_block.slang new file mode 100644 index 00000000..b4d714ee --- /dev/null +++ b/examples/radix_sort/shaders/prefix_block.slang @@ -0,0 +1,35 @@ +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer BlockTotals : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer BlockPrefix : register(u1, space0); + +[numthreads(64, 1, 1)] +void csPrefixBlocks(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint histogramSize = gParams.histogramSize; + uint blockCount = gParams.blockCount; + uint bucket = dispatchThreadID.x; + if (bucket >= histogramSize) + return; + + uint sum = 0; + for (uint block = 0; block < blockCount; ++block) + { + uint idx = block * histogramSize + bucket; + uint total = BlockTotals[idx]; + BlockPrefix[idx] = sum; + sum += total; + } +} diff --git a/examples/radix_sort/shaders/prefix_local.slang b/examples/radix_sort/shaders/prefix_local.slang new file mode 100644 index 00000000..1f87f31a --- /dev/null +++ b/examples/radix_sort/shaders/prefix_local.slang @@ -0,0 +1,51 @@ +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer GroupHistogram : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer GroupPrefix : register(u1, space0); +[[vk::binding(2,0)]] RWStructuredBuffer BlockTotals : register(u2, space0); + +[numthreads(64, 1, 1)] +void csPrefixLocal(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint histogramSize = gParams.histogramSize; + uint numGroups = gParams.numGroups; + uint blockSize = gParams.blockSize; + uint blockCount = gParams.blockCount; + + uint totalThreads = blockCount * histogramSize; + uint index = dispatchThreadID.x; + if (index >= totalThreads) + return; + + uint block = index / histogramSize; + uint bit = index % histogramSize; + + uint startGroup = block * blockSize; + if (startGroup >= numGroups) + { + BlockTotals[index] = 0; + return; + } + + uint endGroup = min(startGroup + blockSize, numGroups); + uint sum = 0; + for (uint g = startGroup; g < endGroup; ++g) + { + uint idx = g * histogramSize + bit; + sum += GroupHistogram[idx]; + GroupPrefix[idx] = sum; + } + BlockTotals[index] = sum; +} diff --git a/examples/radix_sort/shaders/scatter.slang b/examples/radix_sort/shaders/scatter.slang new file mode 100644 index 00000000..dda8d909 --- /dev/null +++ b/examples/radix_sort/shaders/scatter.slang @@ -0,0 +1,105 @@ +static const uint GROUP_SIZE = TF_GROUP_SIZE; +static const uint QUARTER_SIZE = GROUP_SIZE / 4; +static const uint HISTOGRAM_SIZE = TF_HISTOGRAM_SIZE; + +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer KeysIn : register(t0, space0); +[[vk::binding(1,0)]] StructuredBuffer ValuesIn : register(t1, space0); +[[vk::binding(2,0)]] StructuredBuffer GroupPrefix : register(t2, space0); +[[vk::binding(3,0)]] StructuredBuffer BucketScan : register(t3, space0); +[[vk::binding(4,0)]] RWStructuredBuffer KeysOut : register(u4, space0); +[[vk::binding(5,0)]] RWStructuredBuffer ValuesOut : register(u5, space0); + +groupshared uint tempBits[GROUP_SIZE]; +groupshared uint halfCount[HISTOGRAM_SIZE]; + +[numthreads(GROUP_SIZE, 1, 1)] +void csScatter(uint3 groupID : SV_GroupID, uint3 localID : SV_GroupThreadID) +{ + uint elementCount = gParams.elementCount; + uint shift = gParams.shift; + uint mask = gParams.mask; + uint numGroups = gParams.numGroups; + uint hasValues = gParams.hasValues; + + uint group = groupID.x; + if (group >= numGroups) + return; + + uint lane = localID.x; + + [unroll] for (uint idx = lane; idx < HISTOGRAM_SIZE; idx += GROUP_SIZE) + { + if (idx < HISTOGRAM_SIZE) halfCount[idx] = 0; + } + + uint globalIndex = group * GROUP_SIZE + lane; + bool active = (globalIndex < elementCount); + uint bit = 0; + uint key = 0; + uint value = 0; + + if (active) + { + key = KeysIn[globalIndex]; + bit = (key >> shift) & mask; + tempBits[lane] = bit; + if (hasValues != 0) + value = ValuesIn[globalIndex]; + } + else + { + tempBits[lane] = 0; + } + GroupMemoryBarrierWithGroupSync(); + + uint quarterIndex = lane / QUARTER_SIZE; + if (active && quarterIndex < 3) + { + uint inc = 0; + if (quarterIndex < 1) inc |= 1u; + if (quarterIndex < 2) inc |= (1u << 8); + if (quarterIndex < 3) inc |= (1u << 16); + InterlockedAdd(halfCount[bit], inc); + } + GroupMemoryBarrierWithGroupSync(); + + if (active) { + uint prevBucket = (bit == 0) ? 0 : BucketScan[bit - 1u]; + uint prevGroup = (group == 0) ? 0 : GroupPrefix[(group - 1u) * HISTOGRAM_SIZE + bit]; + + uint quarterOffset = 0; + if (quarterIndex > 0) + { + uint packed = halfCount[bit]; + uint shiftBytes = (quarterIndex - 1u) * 8u; + quarterOffset = (packed >> shiftBytes) & 0xFFu; + } + + uint begin = quarterIndex * QUARTER_SIZE; + uint localCount = 0; + [unroll] for (uint t = begin; t < lane; ++t) + { + localCount += (tempBits[t] == bit) ? 1u : 0u; + } + + uint totalOffset = prevBucket + prevGroup + quarterOffset + localCount; + KeysOut[totalOffset] = key; + + if (hasValues != 0) + ValuesOut[totalOffset] = value; + } +} diff --git a/examples/radix_sort/shaders/unpack.slang b/examples/radix_sort/shaders/unpack.slang new file mode 100644 index 00000000..49c0ea83 --- /dev/null +++ b/examples/radix_sort/shaders/unpack.slang @@ -0,0 +1,43 @@ +uint packedCount(uint histogramSize) +{ + return (histogramSize + 3u) >> 2; +} + +struct SortParams { + uint elementCount; + uint histogramSize; + uint shift; + uint mask; + uint numGroups; + uint blockSize; + uint blockCount; + uint hasValues; +}; + +[[vk::push_constant]] +SortParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer PackedHistogram : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer GroupHistogram : register(u1, space0); + +[numthreads(64, 1, 1)] +void csUnpack(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint histogramSize = gParams.histogramSize; + uint numGroups = gParams.numGroups; + uint packedCountLocal = packedCount(histogramSize); + + uint total = numGroups * histogramSize; + uint index = dispatchThreadID.x; + if (index >= total) + return; + + uint group = index / histogramSize; + uint bit = index % histogramSize; + + uint packedIndex = group * packedCountLocal + (bit >> 2); + uint packed = PackedHistogram[packedIndex]; + uint shift = (bit & 3u) * 8u; + uint count = (packed >> shift) & 0xFFu; + GroupHistogram[index] = count; +} diff --git a/examples/radix_sort/shaders/validate_sorted.slang b/examples/radix_sort/shaders/validate_sorted.slang new file mode 100644 index 00000000..e5edc419 --- /dev/null +++ b/examples/radix_sort/shaders/validate_sorted.slang @@ -0,0 +1,51 @@ +static const uint TYPE_UINT = 0u; +static const uint TYPE_INT = 1u; +static const uint TYPE_FLOAT = 2u; +static const uint SIGN_BIT = 0x80000000u; +static const uint FULL_MASK = 0xFFFFFFFFu; +static const uint GROUP_SIZE = TF_GROUP_SIZE; + +struct ValidateParams { + uint elementCount; + uint typeCode; +}; + +[[vk::push_constant]] +ValidateParams gParams; + +[[vk::binding(0,0)]] StructuredBuffer Keys : register(t0, space0); // original key bit patterns +[[vk::binding(1,0)]] RWStructuredBuffer Errors : register(u1, space0); // single uint atomic counter + +// Map a key to sortable uint (same transform as map_to_uint.slang) +uint mapKey(uint raw, uint typeCode) +{ + if (typeCode == TYPE_INT) + { + raw ^= SIGN_BIT; + } + else if (typeCode == TYPE_FLOAT) + { + uint mask = ((raw >> 31) == 1u) ? FULL_MASK : SIGN_BIT; + raw ^= mask; + } + return raw; +} + +[numthreads(GROUP_SIZE, 1, 1)] +void csValidate(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint i = dispatchThreadID.x; + uint n = gParams.elementCount; + if (n < 2 || i >= n - 1) + return; + + uint typeCode = gParams.typeCode; + + uint a = mapKey(Keys[i], typeCode); + uint b = mapKey(Keys[i + 1u], typeCode); + + if (a > b) + { + InterlockedAdd(Errors[0], 1u); + } +} diff --git a/examples/radix_sort/sort.py b/examples/radix_sort/sort.py new file mode 100644 index 00000000..08ed3d85 --- /dev/null +++ b/examples/radix_sort/sort.py @@ -0,0 +1,457 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional, Tuple +import time + +import numpy as np + +import TensorFrost as tf + +__all__ = ["HistogramRadixSort", "radix_sort"] + +_TYPE_CODES: Dict[str, np.uint32] = { + "uint": np.uint32(0), + "int": np.uint32(1), + "float": np.uint32(2), +} + + +def _dispatch_groups(work_items: int, threads_per_group: int) -> int: + if work_items <= 0: + return 0 + return (work_items + threads_per_group - 1) // threads_per_group + + +def _prepare_keys(keys: np.ndarray) -> Tuple[np.ndarray, np.dtype, str]: + array = np.asarray(keys) + if array.ndim != 1: + raise ValueError("radix_sort expects a 1D array of keys") + + dtype = array.dtype + if dtype == np.uint32: + return array, dtype, "uint" + + if dtype == np.int32: + return array, dtype, "int" + + if dtype == np.float32: + return array, dtype, "float" + + raise TypeError(f"Unsupported key dtype {dtype}; expected uint32, int32, or float32") + + +def _prepare_values(values: np.ndarray) -> Tuple[np.ndarray, np.dtype]: + array = np.asarray(values) + if array.ndim != 1: + raise ValueError("radix_sort expects a 1D array of values when provided") + + dtype = array.dtype + if dtype not in (np.uint32, np.int32, np.float32): + raise TypeError(f"Unsupported value dtype {dtype}; expected uint32, int32, or float32") + + return array, dtype + + +_SHADER_DIR = Path(__file__).resolve().parent / "shaders" + + +def _load_shader_source(filename: str) -> str: + shader_path = _SHADER_DIR / filename + return shader_path.read_text(encoding="utf-8") + + +@dataclass(frozen=True) +class _SorterKey: + bits_per_pass: int + block_size: int + group_size: int + + +class HistogramRadixSort: + """GPU histogram radix sort implemented with Slang + Vulkan.""" + + def __init__(self, *, bits_per_pass: int = 6, block_size: int = 64, group_size: int = 128) -> None: + if bits_per_pass <= 0: + raise ValueError("bits_per_pass must be positive") + if bits_per_pass > 8: + raise ValueError("bits_per_pass must be <= 8 to fit within MAX_HIST_SIZE") + if group_size != 128: + raise ValueError("This implementation currently requires group_size == 128") + if block_size <= 0 or block_size > 1024: + raise ValueError("block_size must be within (0, 1024]") + + self.bits_per_pass = bits_per_pass + self.block_size = block_size + self.group_size = group_size + self.histogram_size = 1 << bits_per_pass + self.last_stage_timings = None + self.last_validation_errors: Optional[int] = None + + def inject_defines(filename: str, *, with_group: bool = False, with_histogram: bool = False) -> str: + defines = [] + if with_group: + defines.append(f"#define TF_GROUP_SIZE {self.group_size}u") + if with_histogram: + defines.append(f"#define TF_HISTOGRAM_SIZE {self.histogram_size}u") + source = _load_shader_source(filename) + if defines: + return "\n".join(defines) + "\n" + source + return source + + self._map_to_uint_program = tf.createComputeProgramFromSlang( + "radix_map_to_uint", + inject_defines("map_to_uint.slang", with_group=True), + "csMapToUint", + ro_count=1, + rw_count=1, + push_constant_size=8, + ) + self._map_from_uint_program = tf.createComputeProgramFromSlang( + "radix_map_from_uint", + inject_defines("map_from_uint.slang", with_group=True), + "csMapFromUint", + ro_count=1, + rw_count=1, + push_constant_size=8, + ) + + self._histogram_program = tf.createComputeProgramFromSlang( + "radix_histogram", + inject_defines("histogram.slang", with_group=True), + "csHistogram", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + self._unpack_program = tf.createComputeProgramFromSlang( + "radix_unpack", + _load_shader_source("unpack.slang"), + "csUnpack", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + self._prefix_local_program = tf.createComputeProgramFromSlang( + "radix_prefix_local", + _load_shader_source("prefix_local.slang"), + "csPrefixLocal", + ro_count=1, + rw_count=2, + push_constant_size=32, + ) + self._prefix_blocks_program = tf.createComputeProgramFromSlang( + "radix_prefix_blocks", + _load_shader_source("prefix_block.slang"), + "csPrefixBlocks", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + self._prefix_accum_program = tf.createComputeProgramFromSlang( + "radix_prefix_accum", + _load_shader_source("prefix_accum.slang"), + "csPrefixAccumulate", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + self._bucket_scan_program = tf.createComputeProgramFromSlang( + "radix_bucket_scan", + _load_shader_source("bucket_scan.slang"), + "csBucketScan", + ro_count=1, + rw_count=1, + push_constant_size=32, + ) + scatter_source = inject_defines("scatter.slang", with_group=True, with_histogram=True) + self._scatter_program = tf.createComputeProgramFromSlang( + "radix_scatter", + scatter_source, + "csScatter", + ro_count=4, + rw_count=2, + push_constant_size=32, + ) + + # Validation program: checks if adjacent key pairs are sorted (after mapping back to original type) + self._validate_program = tf.createComputeProgramFromSlang( + "radix_validate_sorted", + inject_defines("validate_sorted.slang", with_group=True), + "csValidate", + ro_count=1, + rw_count=1, + push_constant_size=8, + ) + + self._dummy_values_buffer = tf.createBuffer(1, 4, False) + + def sort( + self, + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + max_bits: int = 32, + collect_stage_timings: bool = False, + validate: bool = False, + return_arrays: bool = True, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + keys_array, key_dtype, key_kind = _prepare_keys(keys) + element_count = int(keys_array.shape[0]) + + if values is not None: + values_array, values_dtype = _prepare_values(values) + if values_array.shape[0] != element_count: + raise ValueError("values must have the same length as keys") + else: + values_array = None + values_dtype = None + + if element_count == 0: + empty_keys = keys_array.copy() + self.last_stage_timings = {} if collect_stage_timings else None + if validate: + self.last_validation_errors = 0 + if values_array is None: + return empty_keys, None + return empty_keys, values_array.copy() + + max_bits = int(min(max_bits, 32)) + histogram_size = self.histogram_size + mask = np.uint32(histogram_size - 1) + + num_groups = max((element_count + self.group_size - 1) // self.group_size, 1) + block_count = max((num_groups + self.block_size - 1) // self.block_size, 1) + packed_count = (histogram_size + 3) // 4 + passes = max((max_bits + self.bits_per_pass - 1) // self.bits_per_pass, 1) + + params_array = np.zeros(8, dtype=np.uint32) + params_array[0] = np.uint32(element_count) + params_array[1] = np.uint32(histogram_size) + params_array[3] = mask + params_array[4] = np.uint32(num_groups) + params_array[5] = np.uint32(self.block_size) + params_array[6] = np.uint32(block_count) + params_array[7] = np.uint32(1 if values_array is not None else 0) + + map_params = np.zeros(2, dtype=np.uint32) + map_params[0] = np.uint32(element_count) + map_params[1] = _TYPE_CODES[key_kind] + + if collect_stage_timings: + stage_totals = { + "map_to_uint": 0.0, + "histogram": 0.0, + "unpack": 0.0, + "prefix_local": 0.0, + "prefix_blocks": 0.0, + "prefix_accum": 0.0, + "bucket_scan": 0.0, + "scatter": 0.0, + "map_from_uint": 0.0, + } + else: + stage_totals = {} + + key_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + key_buffers[0].setData(keys_array) + + if values_array is not None: + value_buffers = [tf.createBuffer(max(element_count, 1), 4, False) for _ in range(2)] + value_buffers[0].setData(values_array) + else: + dummy = self._dummy_values_buffer + value_buffers = [dummy, dummy] + + packed_hist_buffer = tf.createBuffer(max(packed_count * num_groups, 1), 4, False) + group_hist_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + prefix_buffer = tf.createBuffer(max(histogram_size * num_groups, 1), 4, False) + block_totals_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + block_prefix_buffer = tf.createBuffer(max(histogram_size * block_count, 1), 4, False) + bucket_scan_buffer = tf.createBuffer(max(histogram_size, 1), 4, False) + + map_groups = _dispatch_groups(element_count, self.group_size) + reduction_group_size = 64 + unpack_groups = _dispatch_groups(histogram_size * num_groups, reduction_group_size) + prefix_local_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + prefix_block_groups = _dispatch_groups(histogram_size, reduction_group_size) + prefix_accum_groups = _dispatch_groups(histogram_size * block_count, reduction_group_size) + bucket_scan_groups = _dispatch_groups(histogram_size, reduction_group_size) + scatter_groups = num_groups + histogram_groups = num_groups + + # Total pass timer starts at the first dispatch and ends after the last (map_from_uint) + total_start = time.perf_counter() if collect_stage_timings else None + start = time.perf_counter() if collect_stage_timings else None + self._map_to_uint_program.run( + [key_buffers[0]], + [key_buffers[1]], + map_groups, + map_params, + ) + if collect_stage_timings and start is not None: + stage_totals["map_to_uint"] += time.perf_counter() - start + + key_in = key_buffers[1] + key_out = key_buffers[0] + val_in, val_out = value_buffers + + for pass_index in range(passes): + params_array[2] = np.uint32(pass_index * self.bits_per_pass) + + start = time.perf_counter() if collect_stage_timings else None + self._histogram_program.run( + [key_in], + [packed_hist_buffer], + histogram_groups, + params_array, + ) + if collect_stage_timings and start is not None: + stage_totals["histogram"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._unpack_program.run( + [packed_hist_buffer], + [group_hist_buffer], + unpack_groups, + params_array, + ) + if collect_stage_timings and start is not None: + stage_totals["unpack"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._prefix_local_program.run( + [group_hist_buffer], + [prefix_buffer, block_totals_buffer], + prefix_local_groups, + params_array, + ) + if collect_stage_timings and start is not None: + stage_totals["prefix_local"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._prefix_blocks_program.run( + [block_totals_buffer], + [block_prefix_buffer], + prefix_block_groups, + params_array, + ) + if collect_stage_timings and start is not None: + stage_totals["prefix_blocks"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._prefix_accum_program.run( + [block_prefix_buffer], + [prefix_buffer], + prefix_accum_groups, + params_array, + ) + if collect_stage_timings and start is not None: + stage_totals["prefix_accum"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._bucket_scan_program.run( + [prefix_buffer], + [bucket_scan_buffer], + bucket_scan_groups, + params_array, + ) + if collect_stage_timings and start is not None: + stage_totals["bucket_scan"] += time.perf_counter() - start + + start = time.perf_counter() if collect_stage_timings else None + self._scatter_program.run( + [key_in, val_in, prefix_buffer, bucket_scan_buffer], + [key_out, val_out], + scatter_groups, + params_array, + ) + if collect_stage_timings and start is not None: + stage_totals["scatter"] += time.perf_counter() - start + + key_in, key_out = key_out, key_in + if values_array is not None: + val_in, val_out = val_out, val_in + + start = time.perf_counter() if collect_stage_timings else None + self._map_from_uint_program.run( + [key_in], + [key_out], + map_groups, + map_params, + ) + if collect_stage_timings and start is not None: + stage_totals["map_from_uint"] += time.perf_counter() - start + + # Record total time from first to last dispatch in the sort pass + if collect_stage_timings and total_start is not None: + stage_totals["total_pass"] = time.perf_counter() - total_start + + # Optional GPU-side validation: check adjacent pairs and atomically count violations + if validate: + validate_params = np.zeros(2, dtype=np.uint32) + validate_params[0] = np.uint32(element_count) + validate_params[1] = map_params[1] # type code + + error_buf = tf.createBuffer(1, 4, False) + error_zero = np.zeros(1, dtype=np.uint32) + error_buf.setData(error_zero) + + # Reuse map_groups; kernel early-outs for i >= n-1 + self._validate_program.run( + [key_out], + [error_buf], + map_groups, + validate_params, + ) + + error_count = int(error_buf.getData(np.dtype(np.uint32), 1)[0]) + self.last_validation_errors = error_count + + if return_arrays: + sorted_keys = key_out.getData(key_dtype, element_count) + if values_array is not None and values_dtype is not None: + sorted_values = val_in.getData(values_dtype, element_count) + else: + sorted_values = None + else: + # Avoid full readback when not needed by caller + sorted_keys = np.empty(0, dtype=key_dtype) + sorted_values = (np.empty(0, dtype=values_dtype) if values_array is not None and values_dtype is not None else None) + + self.last_stage_timings = stage_totals if collect_stage_timings else None + return sorted_keys, sorted_values + + +_SORTER_CACHE: Dict[_SorterKey, HistogramRadixSort] = {} + + +def _get_sorter(bits_per_pass: int, block_size: int, group_size: int) -> HistogramRadixSort: + key = _SorterKey(bits_per_pass, block_size, group_size) + sorter = _SORTER_CACHE.get(key) + if sorter is None: + sorter = HistogramRadixSort(bits_per_pass=bits_per_pass, block_size=block_size, group_size=group_size) + _SORTER_CACHE[key] = sorter + return sorter + + +def radix_sort( + keys: np.ndarray, + values: Optional[np.ndarray] = None, + *, + bits_per_pass: int = 6, + max_bits: int = 32, + block_size: int = 64, + group_size: int = 128, +): + """Run the GPU histogram radix sort on the provided keys (and optional values). + + Returns the sorted keys, and when ``values`` is provided also returns the permuted values. + """ + + sorter = _get_sorter(bits_per_pass, block_size, group_size) + sorted_keys, sorted_values = sorter.sort(keys, values, max_bits=max_bits) + if values is None: + return sorted_keys + return sorted_keys, sorted_values diff --git a/external/glad b/external/glad deleted file mode 160000 index 658f48e7..00000000 --- a/external/glad +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 658f48e72aee3c6582e80b05ac0f8787a64fe6bb diff --git a/setup_python_env.cmd b/setup_python_env.cmd new file mode 100644 index 00000000..eaf88469 --- /dev/null +++ b/setup_python_env.cmd @@ -0,0 +1,65 @@ +@echo off +setlocal enableextensions + +set "REQUESTED_VERSION=%~1" +if defined REQUESTED_VERSION ( + set "PYTHON_VERSION=%REQUESTED_VERSION%" + set "VERSION_WAS_EXPLICIT=1" +) else ( + set "PYTHON_VERSION=3.12" + set "VERSION_WAS_EXPLICIT=0" +) + +set "SCRIPT_DIR=%~dp0" +if not defined SCRIPT_DIR set "SCRIPT_DIR=.\" +for %%I in ("%SCRIPT_DIR%.") do set "REPO_ROOT=%%~fI" + +set "VENV_DIR=%REPO_ROOT%\.venv" +set "VENV_PYTHON=%VENV_DIR%\Scripts\python.exe" + +set "CREATE_VENV_CMD=py -%PYTHON_VERSION% -m venv" +py -%PYTHON_VERSION% -c "import sys" >nul 2>&1 +if errorlevel 1 ( + if "%VERSION_WAS_EXPLICIT%"=="1" ( + echo [TensorFrost] Python %PYTHON_VERSION% is not available via the launcher. Install it or pass a different version. + exit /b 1 + ) else ( + echo [TensorFrost] Python %PYTHON_VERSION% not found; falling back to default interpreter for venv creation. + set "CREATE_VENV_CMD=py -m venv" + ) +) + +if exist "%VENV_PYTHON%" ( + echo [TensorFrost] Using existing virtual environment at "%VENV_DIR%" +) else ( + echo [TensorFrost] Creating virtual environment at "%VENV_DIR%" + %CREATE_VENV_CMD% "%VENV_DIR%" + if errorlevel 1 ( + echo [TensorFrost] Failed to create virtual environment. + exit /b 1 + ) +) + +if not exist "%VENV_PYTHON%" ( + echo [TensorFrost] Could not find python interpreter in "%VENV_DIR%". + exit /b 1 +) + +echo [TensorFrost] Upgrading pip inside the virtual environment... +"%VENV_PYTHON%" -m pip install --upgrade pip +if errorlevel 1 ( + echo [TensorFrost] Failed to upgrade pip. + exit /b 1 +) + +echo [TensorFrost] Installing TensorFrost in editable mode (verbose)... +"%VENV_PYTHON%" -m pip install -v -e "%REPO_ROOT%\Python" +if errorlevel 1 ( + echo [TensorFrost] Editable install failed. + exit /b 1 +) + +echo [TensorFrost] TensorFrost development environment ready. +echo [TensorFrost] Activate it with "%VENV_DIR%\Scripts\activate.bat" or "pwsh %VENV_DIR%\Scripts\Activate.ps1". + +exit /b 0 diff --git a/tests/imgui_test.py b/tests/imgui_test.py new file mode 100644 index 00000000..11b33d2d --- /dev/null +++ b/tests/imgui_test.py @@ -0,0 +1,156 @@ +import unittest + +import numpy as np +import TensorFrost as tf + + +def _should_skip_for_backend(exc: RuntimeError) -> bool: + message = str(exc).lower() + keywords = ( + "not initialized", + "glfw", + "surface", + "no suitable", + "unavailable", + ) + return any(token in message for token in keywords) + + +class _ManagedWindow: + def __init__(self, width=320, height=240, title="ImGui Test Window"): + self._width = width + self._height = height + self._title = title + self._window = None + + def __enter__(self): + try: + self._window = tf.createWindow(self._width, self._height, self._title) + except RuntimeError as exc: + if _should_skip_for_backend(exc): + raise unittest.SkipTest(f"Window backend unavailable: {exc}") from exc + raise + return self._window + + def __exit__(self, exc_type, exc, tb): + if self._window is not None: + try: + self._window.close() + except Exception: + pass + self._window = None + return False + + +def managed_window(width=320, height=240, title="ImGui Test Window"): + return _ManagedWindow(width, height, title) + + +class ImGuiIntegrationTest(unittest.TestCase): + def test_imgui_basic_widgets(self): + with managed_window() as win: + if win.imgui_begin_main_menu_bar(): + if win.imgui_begin_menu("Root"): + self.assertIn(win.imgui_menu_item("Item"), (True, False)) + win.imgui_end_menu() + win.imgui_end_main_menu_bar() + + visible, open_flag = win.imgui_begin("Main Panel") + self.assertTrue(visible, "ImGui window should be visible on begin") + self.assertIsNone(open_flag, "Default begin call should return None for open flag") + + win.imgui_text("Hello from TensorFrost tests") + win.imgui_same_line() + win.imgui_text("Inline") + win.imgui_spacing() + win.imgui_separator() + win.imgui_indent() + child_visible = win.imgui_begin_child("child", size=(120.0, 60.0), border=False) + self.assertIsInstance(child_visible, bool) + if child_visible: + win.imgui_text_wrapped("Wrapped text inside child region to check bindings work as expected.") + win.imgui_bullet_text("Bullet item content") + win.imgui_end_child() + win.imgui_unindent() + win.imgui_text_colored((1.0, 0.0, 0.0, 1.0), "Colored text") + alpha_scale = win.imgui_get_font_global_scale() + win.imgui_set_font_global_scale(alpha_scale) + style_color = win.imgui_get_style_color_vec4(0) + self.assertEqual(len(style_color), 4) + win.imgui_push_style_color(0, style_color) + win.imgui_push_style_var_float(0, 0.95) + win.imgui_push_style_var_vec2(2, (4.0, 4.0)) + win.imgui_pop_style_var(2) + win.imgui_pop_style_color() + win.imgui_set_style_color_vec4(0, style_color) + + button_pressed = win.imgui_button("Press Me") + self.assertIn(button_pressed, (True, False)) + + self.assertTrue(win.imgui_checkbox("Checkbox", True)) + + slider_value = win.imgui_slider_float("Float Slider", 0.25, 0.0, 1.0) + self.assertGreaterEqual(slider_value, 0.0) + self.assertLessEqual(slider_value, 1.0) + + data = np.linspace(0.0, 1.0, 32, dtype=np.float32) + win.imgui_plot_lines("Plot", data, overlay_text="test", graph_size=(120.0, 40.0)) + + text_changed, text_value = win.imgui_input_text("Input", "hello") + self.assertIsInstance(text_changed, bool) + self.assertIsInstance(text_value, str) + + updated_int_input = win.imgui_input_int("Input Int", 10) + self.assertIsInstance(updated_int_input, int) + + updated_float_input = win.imgui_input_float("Input Float", 3.14) + self.assertIsInstance(updated_float_input, float) + + color_changed3, rgb = win.imgui_color_edit3("Color3", (0.1, 0.2, 0.3)) + self.assertIsInstance(color_changed3, bool) + self.assertEqual(len(rgb), 3) + + color_changed4, rgba = win.imgui_color_edit4("Color4", (0.1, 0.2, 0.3, 0.4)) + self.assertIsInstance(color_changed4, bool) + self.assertEqual(len(rgba), 4) + + if win.imgui_begin_menu_bar(): + if win.imgui_begin_menu("File"): + self.assertIn(win.imgui_menu_item("New"), (True, False)) + win.imgui_end_menu() + win.imgui_end_menu_bar() + + win.imgui_open_popup("ContextPopup") + if win.imgui_begin_popup("ContextPopup"): + win.imgui_text("Popup body") + win.imgui_close_current_popup() + win.imgui_end_popup() + + win.imgui_open_popup("ModalPopup") + visible_modal, modal_open = win.imgui_begin_popup_modal("ModalPopup", open=True) + self.assertIn(visible_modal, (True, False)) + self.assertIsInstance(modal_open, bool) + if visible_modal: + win.imgui_text("Modal body") + win.imgui_close_current_popup() + win.imgui_end_popup() + + win.imgui_add_background_text("BG", (10.0, 10.0), (1.0, 1.0, 1.0, 1.0)) + win.imgui_scale_all_sizes(1.0) + win.imgui_end() + + visible_secondary, open_flag_secondary = win.imgui_begin("Secondary", open=True) + self.assertTrue(visible_secondary) + self.assertIsInstance(open_flag_secondary, bool) + win.imgui_text("Secondary window contents") + updated_int = win.imgui_slider_int("Int Slider", 5, 0, 10) + self.assertGreaterEqual(updated_int, 0) + self.assertLessEqual(updated_int, 10) + win.imgui_end() + + win.present() + self.assertIsInstance(win.isOpen(), bool) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/autograd_test.py b/tests/legacy/autograd_test.py similarity index 98% rename from tests/autograd_test.py rename to tests/legacy/autograd_test.py index 46c66e61..c8c4f0c9 100644 --- a/tests/autograd_test.py +++ b/tests/legacy/autograd_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + import TensorFrost as tf import numpy as np import torch @@ -234,4 +236,6 @@ def test_autograd(self): self.assertTrue(np.allclose(yhat_tf.numpy, yhat.detach().numpy(), atol=1e-5)) for i, param in enumerate(model_torch.parameters()): self.assertTrue(np.allclose(tf_grads.grad[i].numpy, param.grad.detach().numpy(), atol=1e-3)) - self.assertTrue(np.allclose(tf_grads.net.parameters()[i].numpy, param.detach().numpy(), atol=1e-3)) \ No newline at end of file + self.assertTrue(np.allclose(tf_grads.net.parameters()[i].numpy, param.detach().numpy(), atol=1e-3)) +""" + diff --git a/tests/linalg_test.py b/tests/legacy/linalg_test.py similarity index 89% rename from tests/linalg_test.py rename to tests/legacy/linalg_test.py index e40ab565..8ea11f74 100644 --- a/tests/linalg_test.py +++ b/tests/legacy/linalg_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + import numpy as np import TensorFrost as tf import unittest @@ -104,10 +106,12 @@ def test_qr_inversion(self): norm_error = np.linalg.norm(np.dot(Q, R) - np.dot(Qnp, Rnp)) print("QR decomposition error: ", norm_error) self.assertTrue(norm_error < 1e-5) - norm_error = np.linalg.norm(Rinv - Rinvtf) - print("Triangular matrix inversion error: ", norm_error) - self.assertTrue(norm_error < 5e-3) - norm_error = np.linalg.norm(Ainv - Ainvtf) - print("Matrix inversion error: ", norm_error) - self.assertTrue(norm_error < 5e-3) + norm_error = np.linalg.norm(Rinv - Rinvtf) + print("Triangular matrix inversion error: ", norm_error) + self.assertTrue(norm_error < 5e-3) + norm_error = np.linalg.norm(Ainv - Ainvtf) + print("Matrix inversion error: ", norm_error) + self.assertTrue(norm_error < 5e-3) +""" + diff --git a/tests/reshape_reduction_test.py b/tests/legacy/reshape_reduction_test.py similarity index 96% rename from tests/reshape_reduction_test.py rename to tests/legacy/reshape_reduction_test.py index c2ffc3d0..54ff724e 100644 --- a/tests/reshape_reduction_test.py +++ b/tests/legacy/reshape_reduction_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + import TensorFrost as tf import numpy as np import unittest @@ -46,4 +48,6 @@ def test_reduction_reshape(self): self.assertTrue(np.allclose(norm_tf.numpy, norm_np)) self.assertTrue(np.allclose(total_max_tf.numpy, total_max_np)) self.assertTrue(np.allclose(total_min_tf.numpy, total_min_np)) - self.assertTrue(np.allclose(mean_block_tf.numpy, mean_block_np)) \ No newline at end of file + self.assertTrue(np.allclose(mean_block_tf.numpy, mean_block_np)) +""" + diff --git a/tests/sorting_opengl_test.py b/tests/legacy/sorting_opengl_test.py similarity index 96% rename from tests/sorting_opengl_test.py rename to tests/legacy/sorting_opengl_test.py index 8cdb31e8..25e94f29 100644 --- a/tests/sorting_opengl_test.py +++ b/tests/legacy/sorting_opengl_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + # %% import numpy as np import TensorFrost as tf @@ -68,4 +70,5 @@ def test_sorting(self): # check if the results are the same error_radix = np.sum(np.abs(sorted_keys0.numpy - sorted_keys2)) print("Radix float errors: ", error_radix) - self.assertTrue(error_radix == 0) \ No newline at end of file + self.assertTrue(error_radix == 0) +""" diff --git a/tests/sorting_test.py b/tests/legacy/sorting_test.py similarity index 96% rename from tests/sorting_test.py rename to tests/legacy/sorting_test.py index 2fa4a3d0..623f86f9 100644 --- a/tests/sorting_test.py +++ b/tests/legacy/sorting_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + # %% import numpy as np import TensorFrost as tf @@ -68,4 +70,5 @@ def test_sorting(self): # check if the results are the same error_radix = np.sum(np.abs(sorted_keys0.numpy - sorted_keys2)) print("Radix float errors: ", error_radix) - self.assertTrue(error_radix == 0) \ No newline at end of file + self.assertTrue(error_radix == 0) +""" diff --git a/tests/split_dim_test.py b/tests/legacy/split_dim_test.py similarity index 85% rename from tests/split_dim_test.py rename to tests/legacy/split_dim_test.py index 12d05270..211fd699 100644 --- a/tests/split_dim_test.py +++ b/tests/legacy/split_dim_test.py @@ -1,3 +1,5 @@ +"""Autotest temporarily disabled pending updates. + import numpy as np import TensorFrost as tf import unittest @@ -21,4 +23,5 @@ def test_split_dim(self): print(merged.shape) self.assertTrue(merged.shape == [128, 128, 32]) print(splitted.shape) - self.assertTrue(splitted.shape == [4, 32, 128, 32]) \ No newline at end of file + self.assertTrue(splitted.shape == [4, 32, 128, 32]) +""" diff --git a/tests/renderdoc_test.py b/tests/renderdoc_test.py new file mode 100644 index 00000000..8870ee51 --- /dev/null +++ b/tests/renderdoc_test.py @@ -0,0 +1,25 @@ +import unittest + +import TensorFrost as tf + + +class RenderDocBindingTest(unittest.TestCase): + def test_renderdoc_functions_exist(self): + self.assertTrue(hasattr(tf, "renderdoc_start_capture")) + self.assertTrue(hasattr(tf, "renderdoc_end_capture")) + self.assertTrue(hasattr(tf, "renderdoc_is_available")) + + def test_renderdoc_capture_calls(self): + # Calls shouldn't raise even when RenderDoc isn't attached. + tf.renderdoc_start_capture() + path = tf.renderdoc_end_capture() + self.assertIsInstance(path, str) + path = tf.renderdoc_end_capture(launch_replay_ui=False) + self.assertIsInstance(path, str) + + def test_renderdoc_available_returns_bool(self): + self.assertIsInstance(tf.renderdoc_is_available(), bool) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/slang_compile_test.py b/tests/slang_compile_test.py new file mode 100644 index 00000000..bb2c2f8a --- /dev/null +++ b/tests/slang_compile_test.py @@ -0,0 +1,114 @@ +import unittest +from pathlib import Path + +import numpy as np + +import TensorFrost as tf + + +_SIMPLE_SLANG_SHADER = """[[vk::binding(0,0)]] StructuredBuffer InputBuffer : register(t0, space0); +[[vk::binding(1,0)]] RWStructuredBuffer OutputBuffer : register(u1, space0); + +[numthreads(64, 1, 1)] +void csMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + if (dispatchThreadID.x != 0u) + return; + + uint value = InputBuffer[0]; + OutputBuffer[0] = value + 1u; +} +""" + +_REQUIRED_ARTIFACTS = ( + "slang.dll", + "slang-glslang.dll", + "spirv-opt.exe", +) + + +def _runtime_dir() -> Path: + return Path(tf.__file__).resolve().parent + + +def _missing_runtime_artifacts() -> list[str]: + runtime_dir = _runtime_dir() + missing: list[str] = [] + for artifact in _REQUIRED_ARTIFACTS: + release_candidate = runtime_dir / artifact + debug_candidate = runtime_dir / release_candidate.with_name(release_candidate.stem + "d" + release_candidate.suffix).name + if not release_candidate.exists() and not debug_candidate.exists(): + missing.append(artifact) + return missing + +def _should_skip_for_backend(exc: Exception) -> bool: + message = str(exc).lower() + keywords = ( + "glfw", + "vulkan", + "no physical devices", + "no suitable", + "device", + "surface", + "swapchain", + ) + return any(token in message for token in keywords) + + +class SlangCompilationTest(unittest.TestCase): + def test_compile_and_execute_simple_shader(self) -> None: + thread_count = 1 + local_size = 64 + group_count = max((thread_count + local_size - 1) // local_size, 1) + + missing_artifacts = _missing_runtime_artifacts() + if missing_artifacts: # pragma: no cover - environment not staged + pretty = ", ".join(missing_artifacts) + self.skipTest( + "Slang runtime components missing: " + f"{pretty}. Re-run setup_python_env.cmd or rebuild the TensorFrost target to stage runtimes." + ) + + try: + readonly_buffer = tf.createBuffer(thread_count, 4, True) + except RuntimeError as exc: # pragma: no cover - Vulkan not available + self.skipTest(f"Vulkan buffer creation failed: {exc}") + + try: + readwrite_buffer = tf.createBuffer(thread_count, 4, False) + except RuntimeError as exc: # pragma: no cover - Vulkan not available + self.skipTest(f"Vulkan buffer creation failed: {exc}") + + try: + program = tf.createComputeProgramFromSlang( + "tensorfrost_test_shader", + _SIMPLE_SLANG_SHADER, + "csMain", + ro_count=1, + rw_count=1, + ) + except RuntimeError as exc: + if _should_skip_for_backend(exc): # pragma: no cover - Vulkan not available + self.skipTest(f"Slang compilation backend unavailable: {exc}") + raise + + readonly_buffer.setData(np.array([7], dtype=np.uint32)) + readwrite_buffer.setData(np.zeros(1, dtype=np.uint32)) + + program.run([readonly_buffer], [readwrite_buffer], group_count) + + result = readwrite_buffer.getData(np.dtype(np.uint32), thread_count) + self.assertEqual(result.shape, (thread_count,)) + self.assertEqual(int(result[0]), 8) + + readonly_buffer = None + readwrite_buffer = None + program = None + + import gc + + gc.collect() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vulkan_window_test.py b/tests/vulkan_window_test.py new file mode 100644 index 00000000..6d1169b5 --- /dev/null +++ b/tests/vulkan_window_test.py @@ -0,0 +1,68 @@ +import unittest + +import numpy as np + +import TensorFrost as tf + + +_SIMPLE_SLANG = r""" +[[vk::binding(0,0)]] RWStructuredBuffer Pixels : register(u0, space0); + +[numthreads(64,1,1)] +void csMain(uint3 tid : SV_DispatchThreadID) +{ + uint idx = tid.x; + if (idx >= Pixels.length()) return; + Pixels[idx] = 0xff3366ff; +} +""" + + +class VulkanWindowTest(unittest.TestCase): + def test_compute_dispatch_and_window_present(self): + width = height = 8 + thread_count = width * height + local_size = 64 + group_count = max((thread_count + local_size - 1) // local_size, 1) + + try: + pixel_buffer = tf.createBuffer(thread_count, 4, False) + except RuntimeError as exc: # pragma: no cover - Vulkan not available + self.skipTest(f"Vulkan buffer creation failed: {exc}") + + program = None + try: + program = tf.createComputeProgramFromSlang( + "window_test_fill", _SIMPLE_SLANG, "csMain", ro_count=0, rw_count=1 + ) + except RuntimeError as exc: # pragma: no cover - Vulkan program creation failed + self.skipTest(f"Vulkan program creation failed: {exc}") + + window = None + program.run([], [pixel_buffer], group_count) + pixels = pixel_buffer.getData(np.dtype(np.uint32), thread_count) + self.assertTrue(np.all(pixels == 0xFF3366FF), "Compute shader did not write expected color") + + try: + window = tf.createWindow(width, height, "TensorFrost Vulkan Test") + except RuntimeError as exc: # pragma: no cover - Vulkan window not available + self.skipTest(f"Vulkan window creation failed: {exc}") + + # Present once to ensure the binding path is exercised. + window.drawBuffer(pixel_buffer, width, height) + self.assertTrue(window.isOpen(), "Window should report as open after initial present") + + if window is not None: + window.close() + window = None + program = None + pixel_buffer = None + + # Ensure GPU resources are released before interpreter shutdown. + import gc + + gc.collect() + + +if __name__ == "__main__": + unittest.main()