diff --git a/include/infini_operators.h b/include/infini_operators.h index 9a5a2555..aedd3dd1 100644 --- a/include/infini_operators.h +++ b/include/infini_operators.h @@ -16,4 +16,10 @@ #include "ops/rms_norm/rms_norm.h" #include "ops/rotary_embedding/rotary_embedding.h" #include "ops/swiglu/swiglu.h" +#include "ops/reducemax/reducemax.h" +#include "ops/reducemean/reducemean.h" +#include "ops/reducemin/reducemin.h" +#include "ops/clip/clip.h" +#include "ops/where/where.h" +#include "ops/gather/gather.h" #include "tensor/tensor_descriptor.h" diff --git a/include/ops/clip/clip.h b/include/ops/clip/clip.h new file mode 100644 index 00000000..a33c587c --- /dev/null +++ b/include/ops/clip/clip.h @@ -0,0 +1,24 @@ +#ifndef CLIP_H +#define CLIP_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ClipDescriptor { + Device device; +} ClipDescriptor; +typedef ClipDescriptor *infiniopClipDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateClipDescriptor(infiniopHandle_t handle, + infiniopClipDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + float* min, + float* max + ); + +__C __export infiniopStatus_t infiniopClip(infiniopClipDescriptor_t desc, void const *x, void *y, void *stream); + +__C __export infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/ops/gather/gather.h b/include/ops/gather/gather.h new file mode 100644 index 00000000..7d510621 --- /dev/null +++ b/include/ops/gather/gather.h @@ -0,0 +1,24 @@ +#ifndef GAHTER_H +#define GAHTER_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct GatherDescriptor { + Device device; +} GatherDescriptor; +typedef GatherDescriptor *infiniopGatherDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateGatherDescriptor(infiniopHandle_t handle, + infiniopGatherDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t indices, + int64_t axis + ); + +__C __export infiniopStatus_t infiniopGather(infiniopGatherDescriptor_t desc, void const *x, void const *indices, void *y, void *stream); + +__C __export infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/ops/reducemax/reducemax.h b/include/ops/reducemax/reducemax.h new file mode 100644 index 00000000..60da5030 --- /dev/null +++ b/include/ops/reducemax/reducemax.h @@ -0,0 +1,25 @@ +#ifndef REDUCEMAX_H +#define REDUCEMAX_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ReducemaxDescriptor { + Device device; +} ReducemaxDescriptor; +typedef ReducemaxDescriptor *infiniopReducemaxDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReducemaxDescriptor(infiniopHandle_t handle, + infiniopReducemaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + bool keepdims, + bool noop_with_empty_axes + ); + +__C __export infiniopStatus_t infiniopReducemax(infiniopReducemaxDescriptor_t desc, void *y, const void *x, void *dynamic_axes, uint64_t dynamic_axes_size, void *stream); + +__C __export infiniopStatus_t infiniopDestroyReducemaxDescriptor(infiniopReducemaxDescriptor_t desc); +#endif diff --git a/include/ops/reducemean/reducemean.h b/include/ops/reducemean/reducemean.h new file mode 100644 index 00000000..cfb57913 --- /dev/null +++ b/include/ops/reducemean/reducemean.h @@ -0,0 +1,25 
@@ +#ifndef REDUCEMEAN_H +#define REDUCEMEAN_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ReducemeanDescriptor { + Device device; +} ReducemeanDescriptor; +typedef ReducemeanDescriptor *infiniopReducemeanDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReducemeanDescriptor(infiniopHandle_t handle, + infiniopReducemeanDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + bool keepdims, + bool noop_with_empty_axes + ); + +__C __export infiniopStatus_t infiniopReducemean(infiniopReducemeanDescriptor_t desc, void *dst, const void *src, void *dynamic_axes, uint64_t dynamic_axes_size, void *stream); + +__C __export infiniopStatus_t infiniopDestroyReducemeanDescriptor(infiniopReducemeanDescriptor_t desc); +#endif diff --git a/include/ops/reducemin/reducemin.h b/include/ops/reducemin/reducemin.h new file mode 100644 index 00000000..51731d7e --- /dev/null +++ b/include/ops/reducemin/reducemin.h @@ -0,0 +1,25 @@ +#ifndef REDUCEMIN_H +#define REDUCEMIN_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ReduceminDescriptor { + Device device; +} ReduceminDescriptor; +typedef ReduceminDescriptor *infiniopReduceminDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReduceminDescriptor(infiniopHandle_t handle, + infiniopReduceminDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t dst, + infiniopTensorDescriptor_t src, + int64_t const *axes, + uint64_t n, + bool keepdims, + bool noop_with_empty_axes + ); + +__C __export infiniopStatus_t infiniopReducemin(infiniopReduceminDescriptor_t desc, void *dst, const void *src, void *dynamic_axes, uint64_t dynamic_axes_size, void *stream); + +__C __export infiniopStatus_t infiniopDestroyReduceminDescriptor(infiniopReduceminDescriptor_t desc); +#endif diff --git a/include/ops/where/where.h b/include/ops/where/where.h new file mode 100644 index 00000000..ba47d93f --- /dev/null +++ b/include/ops/where/where.h @@ -0,0 +1,24 @@ +#ifndef WHERE_H +#define WHERE_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct WhereDescriptor { + Device device; +} WhereDescriptor; +typedef WhereDescriptor *infiniopWhereDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateWhereDescriptor(infiniopHandle_t handle, + infiniopWhereDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t dst, + infiniopTensorDescriptor_t src1, + infiniopTensorDescriptor_t src2, + infiniopTensorDescriptor_t condition + ); + +__C __export infiniopStatus_t infiniopWhere(infiniopWhereDescriptor_t desc, void *dst, void const *src1, void const *src2, void const *condition, void *stream); + +__C __export infiniopStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/operatorspy/liboperators.py b/operatorspy/liboperators.py index 0909c0cf..3ff8d3ba 100644 --- a/operatorspy/liboperators.py +++ b/operatorspy/liboperators.py @@ -45,6 +45,8 @@ class Handle(Structure): def open_lib(): def find_library_in_ld_path(library_name): ld_library_path = LIB_OPERATORS_DIR + + print(LIB_OPERATORS_DIR) paths = ld_library_path.split(os.pathsep) for path in paths: full_path = os.path.join(path, library_name) diff --git a/operatorspy/tests/clip.py b/operatorspy/tests/clip.py new file mode 100644 index 00000000..0874abe4 --- /dev/null +++ b/operatorspy/tests/clip.py @@ -0,0 +1,172 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_bool, c_float +import ctypes 
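+# Reference semantics (sketch): Clip computes y = min(max(x, min_val), max_val) elementwise;
+# min/max are optional and treated as -inf/+inf when omitted, so torch.clamp(x, min, max)
+# serves as the reference implementation below.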
+import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch +from typing import Tuple +import numpy as np + +PROFILE = True +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +class ClipDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopClipDescriptor_t = POINTER(ClipDescriptor) + +def clip(input, min, max): + return torch.clamp(input, min, max) + + +def tuple_to_void_p(py_tuple: Tuple): + array = ctypes.c_int64 * len(py_tuple) + data_array = array(*py_tuple) + return ctypes.cast(data_array, ctypes.c_void_p) + +def test( + lib, + handle, + torch_device, + x_shape, + min, + max, + tensor_dtype=torch.float32 +): + print( + f"Testing clip on {torch_device} with x_shape:{x_shape} dtype:{tensor_dtype} max:{max} min:{min}" + ) + x = torch.randn(x_shape, dtype=torch.float32, device=torch_device) + + output = torch.randn(x_shape, dtype=torch.float32, device=torch_device) + if min != None: + min_t = torch.tensor(min, dtype=torch.float32, device=torch_device) + else: + min_t = torch.tensor(float("-inf"), dtype=torch.float32, device=torch_device) + if max != None: + max_t = torch.tensor(max, dtype=torch.float32, device=torch_device) + else: + max_t = torch.tensor(float("inf"), dtype=torch.float32, device=torch_device) + for i in range(NUM_PRERUN if PROFILE else 1): + if min == None and max == None: + break + ans = clip(x, min_t, max_t) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = clip(x, min_t, max_t) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :10f}") + x_tensor = to_tensor(x, lib) + y_tensor = to_tensor(output, lib) + descriptor = infiniopClipDescriptor_t() + check_error( + lib.infiniopCreateClipDescriptor( + handle, + ctypes.byref(descriptor), + x_tensor.descriptor, + y_tensor.descriptor, + ctypes.byref(c_float(min)) if min != None else None, + ctypes.byref(c_float(max)) if max != None else None, + ) + ) + #Ss = [1024, 2048, 4096] + x_tensor.descriptor.contents.invalidate() + y_tensor.descriptor.contents.invalidate() + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopClip( + descriptor, + x_tensor.data, + y_tensor.data, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + check_error( + lib.infiniopClip( + descriptor, + x_tensor.data, + y_tensor.data, + None, + ) + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"lib time: {elapsed :10f}") + assert torch.allclose(output, ans, atol=0, rtol=0) if max != None or min != None else torch.allclose(output, x, atol=0, rtol=0) + check_error(lib.infiniopDestroyClipDescriptor(descriptor)) + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, min, max, tensor_type in test_cases: + test(lib, handle, "cpu", x_shape, min, max, tensor_dtype=tensor_type) + destroy_handle(lib, handle) + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, min, max, tensor_type in test_cases: + test(lib, handle, "cuda", x_shape, min, max, tensor_dtype=tensor_type) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + ((3, 4), -1, 1, torch.float32), + ((3, 
4), None, 1, torch.float32), + ((3, 4), -1, None, torch.float32), + ((3, 4), None, None, torch.float32), + ((16), -1, 1, torch.float32), + # ((1024, 1024), -1, 1, torch.float32), + # ((4096, 4096), -1, 1, torch.float32), + + ((13), -1, 1, torch.float32), + ((3, 4), -1, 1, torch.float16), + ((3, 4), None, 1, torch.float16), + ((3, 4), -1, None, torch.float16), + ((3, 4), None, None, torch.float16), + ((16), -1, 1, torch.float16), + # ((1024, 1024), -1, 1, torch.float16), + # ((4096, 4096), -1, 1, torch.float16), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateClipDescriptor.restype = c_int32 + lib.infiniopCreateClipDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopClipDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t + ] + lib.infiniopClip.restype = c_int32 + lib.infiniopClip.argtypes = [ + infiniopClipDescriptor_t, + c_void_p, + c_void_p, + c_void_p + ] + lib.infiniopDestroyClipDescriptor.restype = c_int32 + lib.infiniopDestroyClipDescriptor.argtypes = [infiniopClipDescriptor_t] + if args.cuda: + test_cuda(lib, test_cases) + if args.cpu: + test_cpu(lib, test_cases) + print("All tests passed!") \ No newline at end of file diff --git a/operatorspy/tests/gather.py b/operatorspy/tests/gather.py new file mode 100644 index 00000000..2c01cada --- /dev/null +++ b/operatorspy/tests/gather.py @@ -0,0 +1,163 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_bool +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch +from typing import Tuple +import numpy as np + +PROFILE = True +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +class GatherDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopGatherDescriptor_t = POINTER(GatherDescriptor) + +def gather(x, indices, axis = 0): + idx = [slice(None)] * x.ndim + idx[axis] = indices + return x[tuple(idx)] + +def tuple_to_void_p(py_tuple: Tuple): + array = ctypes.c_int64 * len(py_tuple) + data_array = array(*py_tuple) + return ctypes.cast(data_array, ctypes.c_void_p) + +def inferShape(input_shape, indices_shape, axis): + output_shape = input_shape[:axis] + tuple(indices_shape) + input_shape[axis + 1:] + return output_shape + +def test( + lib, + handle, + torch_device, + x_shape, + indices_shape, + axis, + tensor_dtype=torch.float16 +): + print( + f"Testing gather on {torch_device} with x_shape:{x_shape} dtype:{tensor_dtype}" + ) + x = torch.randn(x_shape, dtype=tensor_dtype, device=torch_device) + if isinstance(indices_shape, int): + indices_shape_tuple = (indices_shape,) + else: + indices_shape_tuple = tuple(indices_shape) + indices = torch.randint(0, x.shape[axis], indices_shape_tuple, + device=torch_device).type(torch.int64) + dst = torch.randn(inferShape(x_shape, indices.shape, axis), dtype=tensor_dtype, device=torch_device) + + ans = gather(x, indices, axis) + + x_tensor = to_tensor(x, lib) + indices_tensor = to_tensor(indices, lib) + dst_tensor = to_tensor(dst, lib) + descriptor = infiniopGatherDescriptor_t() + check_error( + lib.infiniopCreateGatherDescriptor( + handle, + ctypes.byref(descriptor), + dst_tensor.descriptor, + x_tensor.descriptor, + indices_tensor.descriptor, + axis + ) + ) + x_tensor.descriptor.contents.invalidate() + 
indices_tensor.descriptor.contents.invalidate() + dst_tensor.descriptor.contents.invalidate() + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopGather( + descriptor, + x_tensor.data, + indices_tensor.data, + dst_tensor.data, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + check_error( + lib.infiniopGather( + descriptor, + x_tensor.data, + indices_tensor.data, + dst_tensor.data, + None, + ) + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"lib time: {elapsed :10f}") + ans = ans.to(torch_device) + assert torch.allclose(dst, ans, atol=0, rtol=0) + check_error(lib.infiniopDestroyGatherDescriptor(descriptor)) + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, indices_shape, axis, tensor_dtype in test_cases: + test(lib, handle, "cpu", x_shape, indices_shape, axis, tensor_dtype=tensor_dtype) + destroy_handle(lib, handle) + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, indices_shape, axis, tensor_dtype in test_cases: + test(lib, handle, "cuda", x_shape, indices_shape, axis, tensor_dtype=tensor_dtype) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + ((3, 4), (2), 0, torch.float32), + ((64, 64), (64, 64), 0, torch.float32), + ((64, 64), (64, 64), 1, torch.float32), + ((2, 3, 4), (2, 2), 1, torch.float32), + ((64, 64), (64, 64), 0, torch.float16), + ((64, 64), (64, 64), 1, torch.float16), + ((8, 8, 8, 8, 8), (8, 8), 0, torch.float16), + ((8, 8, 8, 8, 8), (8, 8), 2, torch.float16), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateGatherDescriptor.restype = c_int32 + lib.infiniopCreateGatherDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopGatherDescriptor_t), + infiniopTensorDescriptor_t, + ] + lib.infiniopGather.restype = c_int32 + lib.infiniopGather.argtypes = [ + infiniopGatherDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyGatherDescriptor.restype = c_int32 + lib.infiniopDestroyGatherDescriptor.argtypes = [infiniopGatherDescriptor_t] + if args.cuda: + test_cuda(lib, test_cases) + if args.cpu: + test_cpu(lib, test_cases) + print("All tests passed!") \ No newline at end of file diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index ac4b0f7f..a949390e 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -22,9 +22,9 @@ from operatorspy.tests.test_utils import get_args, synchronize_device import torch -PROFILE = False +PROFILE = True NUM_PRERUN = 10 -NUM_ITERATIONS = 1000 +NUM_ITERATIONS = 50 class MatmulDescriptor(Structure): _fields_ = [("device", c_int32)] diff --git a/operatorspy/tests/max_pool.py b/operatorspy/tests/max_pool.py index ffc0bb19..7cd4ef8a 100644 --- a/operatorspy/tests/max_pool.py +++ b/operatorspy/tests/max_pool.py @@ -23,7 +23,7 @@ # constant for control whether profile the pytorch and lib functions # NOTE: need to manually add synchronization function to the lib function, # e.g., cudaDeviceSynchronize() for CUDA -PROFILE = False +PROFILE = True NUM_PRERUN = 10 NUM_ITERATIONS = 1000 diff --git a/operatorspy/tests/reducemax.py b/operatorspy/tests/reducemax.py new file mode 100644 index 00000000..a8873a2a --- /dev/null +++ b/operatorspy/tests/reducemax.py @@ -0,0 +1,224 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_bool +import ctypes +import sys +import os +import time + 
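+# Reference semantics (sketch): ReduceMax reduces over the given axes (all axes when none
+# are supplied, unless noop_with_empty_axes is set, in which case the input passes through
+# unchanged); keepdims keeps each reduced dimension with extent 1. torch.amax(x, dim=axes,
+# keepdim=keepdims) serves as the reference implementation below.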
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch +from typing import Tuple +import numpy as np + +PROFILE = True +NUM_PRERUN = 1 +NUM_ITERATIONS = 50 + +class ReducemaxDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopReducemaxDescriptor_t = POINTER(ReducemaxDescriptor) + +def reduce_max(input, axis, noop_with_empty_axes, keepdims=True): + if axis == None: + if noop_with_empty_axes: + return input + else: + return torch.amax(input, dim=axis, keepdim=keepdims) + return torch.amax(input, dim=axis, keepdim=keepdims) + +def inferShape(x_shape, axis, noop_with_empty_axes, keepdims=False): + if axis == None: + if noop_with_empty_axes: + return x_shape + else: + if keepdims: + return tuple([1] * len(x_shape)) + else: + return tuple([]) + assert len(axis) <= len(x_shape), "axis out of range" + output_shape = [] + axis = [a if a >= 0 else a + len(x_shape) for a in axis] # 更新 axis 列表中的值 + for a in axis: + assert 0 <= a <= len(x_shape) - 1, "axis out of range" + for i, s in enumerate(x_shape): + if i in axis and keepdims: + output_shape.append(1) + elif i in axis and not keepdims: + continue + else: + output_shape.append(s) + + print(f"output_shape = {output_shape}") + return tuple(output_shape) + +def tuple_to_void_p(py_tuple: Tuple): + array = ctypes.c_int64 * len(py_tuple) + data_array = array(*py_tuple) + return ctypes.cast(data_array, ctypes.c_void_p) + +def test( + lib, + handle, + torch_device, + x_shape, + axes, + dynamic_axes, + noop_with_empty_axes=False, + keepdims=True, + tensor_dtype=torch.float16 +): + print( + f"Testing reducemax on {torch_device} with x_shape:{x_shape} dtype:{tensor_dtype}" + ) + x = torch.randn(x_shape, dtype=tensor_dtype, device=torch_device) + print(f"y_shape = {inferShape(x_shape, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims)}") + y = torch.full(inferShape(x_shape, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims), float('-inf'), dtype=tensor_dtype, device=torch_device) + print(f"y_shape = {y.shape}") + for i in range(NUM_PRERUN if PROFILE else 1): + ans = reduce_max(x, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = reduce_max(x, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :10f}") + x_tensor = to_tensor(x, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopReducemaxDescriptor_t() + axe = tuple_to_void_p(axes) if axes != None else None + lenth_axes = c_uint64(len(axes)) if axes != None else c_uint64(0) + dynamic_axes_parm = tuple_to_void_p(dynamic_axes) if dynamic_axes != None else None + lenth_dynamic_axes = c_uint64(len(dynamic_axes)) if dynamic_axes != None else c_uint64(0) + check_error( + lib.infiniopCreateReducemaxDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + x_tensor.descriptor, + axe, + lenth_axes, + c_bool(keepdims), + c_bool(noop_with_empty_axes), + ) + ) + x_tensor.descriptor.contents.invalidate() + y_tensor.descriptor.contents.invalidate() + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopReducemax( + 
descriptor, + y_tensor.data, + x_tensor.data, + dynamic_axes_parm, + lenth_dynamic_axes, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + check_error( + lib.infiniopReducemax( + descriptor, + y_tensor.data, + x_tensor.data, + dynamic_axes_parm, + lenth_dynamic_axes, + None, + ) + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"lib time: {elapsed :10f}") + # print(f"input : {x}") + # print(f"custom op output:{y}") + # print(f"pytorch output:{ans}") + check_error(lib.infiniopDestroyReducemaxDescriptor(descriptor)) + assert torch.allclose(y, ans, atol=0, rtol=1e-3) + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype in test_cases: + test(lib, handle, "cpu", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=tensor_dtype) + print("\n") + #test(lib, handle, "cpu", x_shape, axes, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype in test_cases: + test(lib, handle, "cuda", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=tensor_dtype) + print("\n") + destroy_handle(lib, handle) + +if __name__ == "__main__": + test_cases = [ + # dynamic calc test eg + # ((2, 3, 4, 5), [0, 2], False, True, None), + # ((2, 3, 4, 5), [0, 2], False, True, None), + # #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes) + # ((2, 10, 24, 10), [0, 2], False, True, None), + # # stride = + # ((2, 10, 24, 10), [0, 1], False, True, None), + # ((2, 10, 24, 10), [2, 3], False , True, None), + # ((2, 10, 24, 10), [0, 1, 2, 3], False, True, None), + # # validate attribute noop_with_empty_axes and keepdims + # ((2, 10, 24, 10), None, True, True, None), + # ((2, 10, 24, 10), None, True, False, None), + # ((2, 10, 24, 10), None, False, True, None), + # ((2, 10, 24, 10), None, False, False, None), + # ((2, 3, 4), [0, 1], False, False, None), + #((2, 10, 24, 10), [], True), + #((4,), [0], False, False, None, torch.float32), + ((1000, 300), [0, 1], False, False, None, torch.float16), + ((50, 3), [0, 1], False, False, None, torch.float16), + ((1000, 300), [0, 1], False, False, None, torch.float16), + ((2000, 200, 50), [0, 1], False, True, None, torch.float32), + ((1000, 200, 500), [0, 1], False, True, None, torch.float16), + ((1000, 200, 50), [0, 1], False, True, None, torch.float32), + ((20, 3, 4, 5), [0, 2], False, False, None, torch.float32), + ((20, 30, 40, 5), [0, 2, 3], False, False, None, torch.float32), + ((200, 3, 40, 5), [0, 3], False, False, None, torch.float32), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateReducemaxDescriptor.restype = c_int32 + lib.infiniopCreateReducemaxDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopReducemaxDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_uint64, + c_bool, + c_bool + ] + lib.infiniopReducemax.restype = c_int32 + lib.infiniopReducemax.argtypes = [ + infiniopReducemaxDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_uint64, + c_void_p, + ] + lib.infiniopDestroyReducemaxDescriptor.restype = c_int32 + lib.infiniopDestroyReducemaxDescriptor.argtypes = [infiniopReducemaxDescriptor_t] + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + 
print("All tests passed!") \ No newline at end of file diff --git a/operatorspy/tests/reducemean.py b/operatorspy/tests/reducemean.py new file mode 100644 index 00000000..009b6646 --- /dev/null +++ b/operatorspy/tests/reducemean.py @@ -0,0 +1,226 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_bool +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch +from typing import Tuple +import numpy as np + +PROFILE = True +NUM_PRERUN = 1 +NUM_ITERATIONS = 50 + +class ReducemeanDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopReducemeanDescriptor_t = POINTER(ReducemeanDescriptor) + +def reduce_mean(input, axis, noop_with_empty_axes, keepdims=True): + if axis == None: + if noop_with_empty_axes: + return input + else: + return torch.mean(input, dim=axis, keepdim=keepdims) + return torch.mean(input, dim=axis, keepdim=keepdims) + +def inferShape(x_shape, axis, noop_with_empty_axes, keepdims=False): + if axis == None: + if noop_with_empty_axes: + return x_shape + else: + if keepdims: + return tuple([1] * len(x_shape)) + else: + return tuple([]) + + assert len(axis) <= len(x_shape), "axis out of range" + output_shape = [] + axis = [a if a >= 0 else a + len(x_shape) for a in axis] # 更新 axis 列表中的值 + for a in axis: + assert 0 <= a <= len(x_shape) - 1, "axis out of range" + for i, s in enumerate(x_shape): + if i in axis and keepdims: + output_shape.append(1) + elif i in axis and not keepdims: + continue + else: + output_shape.append(s) + + print(f"output_shape = {output_shape}") + return tuple(output_shape) + +def tuple_to_void_p(py_tuple: Tuple): + array = ctypes.c_int64 * len(py_tuple) + data_array = array(*py_tuple) + return ctypes.cast(data_array, ctypes.c_void_p) + +def test( + lib, + handle, + torch_device, + x_shape, + axes, + dynamic_axes, + noop_with_empty_axes=False, + keepdims=True, + tensor_dtype=torch.float16 +): + print( + f"Testing reducemean on {torch_device} with x_shape:{x_shape} dtype:{tensor_dtype}" + ) + x = torch.randint(0, 10, x_shape, dtype=tensor_dtype, device=torch_device) + print(f"y_shape = {inferShape(x_shape, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims)}") + y = torch.full(inferShape(x_shape, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims), float(0), dtype=tensor_dtype, device=torch_device) + print(f"y_shape = {y.shape}") + for i in range(NUM_PRERUN if PROFILE else 1): + ans = reduce_mean(x, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = reduce_mean(x, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :10f}") + x_tensor = to_tensor(x, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopReducemeanDescriptor_t() + axe = tuple_to_void_p(axes) if axes != None else None + lenth_axes = c_uint64(len(axes)) if axes != None else c_uint64(0) + dynamic_axes_parm = tuple_to_void_p(dynamic_axes) if dynamic_axes != None else None + lenth_dynamic_axes = c_uint64(len(dynamic_axes)) if dynamic_axes != None else c_uint64(0) + 
check_error( + lib.infiniopCreateReducemeanDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + x_tensor.descriptor, + axe, + lenth_axes, + c_bool(keepdims), + c_bool(noop_with_empty_axes), + ) + ) + x_tensor.descriptor.contents.invalidate() + y_tensor.descriptor.contents.invalidate() + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopReducemean( + descriptor, + y_tensor.data, + x_tensor.data, + dynamic_axes_parm, + lenth_dynamic_axes, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + check_error( + lib.infiniopReducemean( + descriptor, + y_tensor.data, + x_tensor.data, + dynamic_axes_parm, + lenth_dynamic_axes, + None, + ) + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"lib time: {elapsed :10f}") + #print(f"input_data = {x}") + # print(f"custom op output:{y}") + # print(f"pytorch output:{ans}") + assert torch.allclose(y, ans, atol=0, rtol=1e-3) + + check_error(lib.infiniopDestroyReducemeanDescriptor(descriptor)) + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype in test_cases: + test(lib, handle, "cpu", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=tensor_dtype) + print("\n") + #test(lib, handle, "cpu", x_shape, axes, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype in test_cases: + test(lib, handle, "cuda", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=tensor_dtype) + print("\n") + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # dynamic calc test eg + # ((2, 3, 4, 5), [0, 2], False, True, None), + # ((2, 3, 4, 5), [0, 2], False, True, None), + # #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes) + # ((2, 10, 24, 10), [0, 2], False, True, None), + # # stride = + # ((2, 10, 24, 10), [0, 1], False, True, None), + # ((2, 10, 24, 10), [2, 3], False , True, None), + #((1000, 300), [0, 1], False, False, None, torch.float16), + ((30, 5, 20, 100), [0, 1, 2, 3], False, False, None, torch.float16), + ((30000, 1000, 40), [0, 1], False, False, None, torch.float32), + #((1000, 300), [0, 1], False, False, None, torch.float16), + ((2, 2, 5), [0, 1], False, True, None, torch.float32), + ((1000, 200, 500), [0, 1], False, True, None, torch.float16), + ((1000, 200, 50), [0, 1], False, True, None, torch.float32), + ((20, 3, 4, 5), [0, 2], False, False, None, torch.float32), + ((20, 30, 40, 5), [0, 2, 3], False, False, None, torch.float32), + ((200, 3, 40, 5), [0, 3], False, False, None, torch.float32), + # validate attribute noop_with_empty_axes and keepdims + # ((2, 10, 24, 10), None, True, True, None), + # ((2, 10, 24, 10), None, True, False, None), + # ((2, 10, 24, 10), None, False, True, None), + # ((2, 10, 24, 10), None, False, False, None), + # ((2, 3, 4), [0, 1], False, False, None), + #((2, 10, 24, 10), [], True), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateReducemeanDescriptor.restype = c_int32 + lib.infiniopCreateReducemeanDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopReducemeanDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_uint64, + c_bool, + c_bool + ] + 
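+# The runtime call mirrors the declaration in reducemean.h:
+# infiniopReducemean(desc, dst, src, dynamic_axes, dynamic_axes_size, stream),
+# so dynamic_axes is passed as a c_void_p pointing at an int64 array and its
+# length as c_uint64.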
lib.infiniopReducemean.restype = c_int32 + lib.infiniopReducemean.argtypes = [ + infiniopReducemeanDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_uint64, + c_void_p, + ] + lib.infiniopDestroyReducemeanDescriptor.restype = c_int32 + lib.infiniopDestroyReducemeanDescriptor.argtypes = [infiniopReducemeanDescriptor_t] + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + print("All tests passed!") \ No newline at end of file diff --git a/operatorspy/tests/reducemin.py b/operatorspy/tests/reducemin.py new file mode 100644 index 00000000..c590b71b --- /dev/null +++ b/operatorspy/tests/reducemin.py @@ -0,0 +1,220 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_bool +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch +from typing import Tuple +import numpy as np + +PROFILE = True +NUM_PRERUN = 1 +NUM_ITERATIONS = 1 + +class ReduceminDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopReduceminDescriptor_t = POINTER(ReduceminDescriptor) + +def reduce_min(input, axis, noop_with_empty_axes, keepdims=True): + if axis == None: + if noop_with_empty_axes: + return input + else: + return torch.amin(input, dim=axis, keepdim=keepdims) + return torch.amin(input, dim=axis, keepdim=keepdims) + +def inferShape(x_shape, axis, noop_with_empty_axes, keepdims=False): + if axis == None: + if noop_with_empty_axes: + return x_shape + else: + if keepdims: + return tuple([1] * len(x_shape)) + else: + return tuple([]) + assert len(axis) <= len(x_shape), "axis out of range" + output_shape = [] + axis = [a if a >= 0 else a + len(x_shape) for a in axis] # 更新 axis 列表中的值 + for a in axis: + assert 0 <= a <= len(x_shape) - 1, "axis out of range" + for i, s in enumerate(x_shape): + if i in axis and keepdims: + output_shape.append(1) + elif i in axis and not keepdims: + continue + else: + output_shape.append(s) + + print(f"output_shape = {output_shape}") + return tuple(output_shape) + +def tuple_to_void_p(py_tuple: Tuple): + array = ctypes.c_int64 * len(py_tuple) + data_array = array(*py_tuple) + return ctypes.cast(data_array, ctypes.c_void_p) + +def test( + lib, + handle, + torch_device, + x_shape, + axes, + dynamic_axes, + noop_with_empty_axes=False, + keepdims=True, + tensor_dtype=torch.float16 +): + print( + f"Testing reducemin on {torch_device} with x_shape:{x_shape} dtype:{tensor_dtype}" + ) + x = torch.randn(x_shape, dtype=tensor_dtype, device=torch_device) + print(f"y_shape = {inferShape(x_shape, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims)}") + y = torch.full(inferShape(x_shape, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims), float('inf'), dtype=tensor_dtype, device=torch_device) + print(f"y_shape = {y.shape}") + for i in range(NUM_PRERUN if PROFILE else 1): + ans = reduce_min(x, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = reduce_min(x, axes if dynamic_axes == None else dynamic_axes, noop_with_empty_axes, keepdims) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :10f}") + x_tensor = 
to_tensor(x, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopReduceminDescriptor_t() + axe = tuple_to_void_p(axes) if axes != None else None + lenth_axes = c_uint64(len(axes)) if axes != None else c_uint64(0) + dynamic_axes_parm = tuple_to_void_p(dynamic_axes) if dynamic_axes != None else None + lenth_dynamic_axes = c_uint64(len(dynamic_axes)) if dynamic_axes != None else c_uint64(0) + check_error( + lib.infiniopCreateReduceminDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + x_tensor.descriptor, + axe, + lenth_axes, + c_bool(keepdims), + c_bool(noop_with_empty_axes), + ) + ) + x_tensor.descriptor.contents.invalidate() + y_tensor.descriptor.contents.invalidate() + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopReducemin( + descriptor, + y_tensor.data, + x_tensor.data, + dynamic_axes_parm, + lenth_dynamic_axes, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + check_error( + lib.infiniopReducemin( + descriptor, + y_tensor.data, + x_tensor.data, + dynamic_axes_parm, + lenth_dynamic_axes, + None, + ) + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"lib time: {elapsed :10f}") + # print(f"custom op output:{y}") + # print(f"pytorch output:{ans}") + assert torch.allclose(y, ans, atol=0, rtol=1e-3) + + check_error(lib.infiniopDestroyReducemaxDescriptor(descriptor)) + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype in test_cases: + test(lib, handle, "cpu", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=tensor_dtype) + #test(lib, handle, "cpu", x_shape, axes, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype in test_cases: + test(lib, handle, "cuda", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=tensor_dtype) + print("\n") + destroy_handle(lib, handle) + +if __name__ == "__main__": + test_cases = [ + # dynamic calc test eg + # ((2, 3, 4, 5), [0, 2], False, True, None), + # ((2, 3, 4, 5), [0, 2], False, True, None), + # #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes) + # ((2, 10, 24, 10), [0, 2], False, True, None), + # # stride = + # ((2, 10, 24, 10), [0, 1], False, True, None), + # ((2, 10, 24, 10), [2, 3], False , True, None), + # ((2, 10, 24, 10), [0, 1, 2, 3], False, True, None), + # # validate attribute noop_with_empty_axes and keepdims + # ((2, 10, 24, 10), None, True, True, None), + # ((2, 10, 24, 10), None, True, False, None), + # ((2, 10, 24, 10), None, False, True, None), + # ((2, 10, 24, 10), None, False, False, None), + # ((2, 3, 4), [0, 1], False, False, None), + # #((2, 10, 24, 10), [], True), + ((2, 1000), [0, 1], False, False, None, torch.float32), + ((2, 2, 5), [0, 1], False, True, None, torch.float32), + ((1000, 200, 500), [0, 1], False, True, None, torch.float16), + ((1000, 200, 50), [0, 1], False, True, None, torch.float32), + ((20, 3, 4, 5), [0, 2], False, False, None, torch.float32), + ((20, 30, 40, 5), [0, 2, 3], False, False, None, torch.float32), + ((200, 3, 40, 5), [0, 3], False, False, None, torch.float32), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateReduceminDescriptor.restype = c_int32 + 
lib.infiniopCreateReduceminDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopReduceminDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_uint64, + c_bool, + c_bool + ] + lib.infiniopReducemin.restype = c_int32 + lib.infiniopReducemin.argtypes = [ + infiniopReduceminDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_uint64, + c_void_p, + ] + lib.infiniopDestroyReduceminDescriptor.restype = c_int32 + lib.infiniopDestroyReduceminDescriptor.argtypes = [infiniopReduceminDescriptor_t] + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + print("All tests passed!") \ No newline at end of file diff --git a/operatorspy/tests/where.py b/operatorspy/tests/where.py new file mode 100644 index 00000000..f8809433 --- /dev/null +++ b/operatorspy/tests/where.py @@ -0,0 +1,192 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_bool +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch +from typing import Tuple +import numpy as np + +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +class WhereDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopWhereDescriptor_t = POINTER(WhereDescriptor) + +def where(condition, x, y): + return torch.where(condition, x, y) + + +def tuple_to_void_p(py_tuple: Tuple): + array = ctypes.c_int64 * len(py_tuple) + data_array = array(*py_tuple) + return ctypes.cast(data_array, ctypes.c_void_p) + +def inferShape(x_shape, y_shape): + ndim_x = len(x_shape) + ndim_y = len(y_shape) + ndim = max(ndim_x, ndim_y) + output_shape = [] + + for i in range(-1, -ndim-1, -1): + dim_x = x_shape[i] if i >= -ndim_x else 1 + dim_y = y_shape[i] if i >= -ndim_y else 1 + + if dim_x != dim_y: + if dim_x != 1 and dim_y != 1: + raise ValueError(f"Shapes {x_shape} and {y_shape} cannot be broadcast together") + + output_dim = max(dim_x, dim_y) + output_shape.insert(0, output_dim) + + return tuple(output_shape) + + +def test( + lib, + handle, + torch_device, + condition_shape, + src1_shape, + src2_shape, + tensor_dtype=torch.float16 +): + print( + f"Testing where on {torch_device} with condition_shape:{condition_shape} dtype:{tensor_dtype}" + ) + condition = torch.randint(0, 2, condition_shape, dtype=torch.uint8).to(torch_device) + src1 = torch.randn(src1_shape, dtype=tensor_dtype, device=torch_device) + src2 = torch.randn(src2_shape, dtype=tensor_dtype, device=torch_device) + output = torch.randn(inferShape(inferShape(src1_shape, src2_shape), condition_shape), dtype=tensor_dtype, device=torch_device) + + for i in range(NUM_PRERUN if PROFILE else 1): + ans = where(condition, src1, src2) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = where(condition, src1, src2) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :10f}") + src1_tensor = to_tensor(src1, lib) + src2_tensor = to_tensor(src2, lib) + output_tensor = to_tensor(output, lib) + condition_tensor = to_tensor(condition, lib) + descriptor = infiniopWhereDescriptor_t() + check_error( + lib.infiniopCreateWhereDescriptor( + handle, + ctypes.byref(descriptor), + output_tensor.descriptor, + src1_tensor.descriptor, + 
src2_tensor.descriptor, + condition_tensor.descriptor, + ) + ) + src1_tensor.descriptor.contents.invalidate() + src2_tensor.descriptor.contents.invalidate() + output_tensor.descriptor.contents.invalidate() + condition_tensor.descriptor.contents.invalidate() + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopWhere( + descriptor, + output_tensor.data, + src1_tensor.data, + src2_tensor.data, + condition_tensor.data, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + check_error( + lib.infiniopWhere( + descriptor, + output_tensor.data, + src1_tensor.data, + src2_tensor.data, + condition_tensor.data, + None, + ) + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"lib time: {elapsed :10f}") + assert torch.allclose(output, ans, atol=0, rtol=0) + check_error(lib.infiniopDestroyWhereDescriptor(descriptor)) + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for condition_shape, src1_shape, src2_shape, tensor_dtype in test_cases: + test(lib, handle, "cpu", condition_shape, src1_shape, src2_shape, tensor_dtype=tensor_dtype) + print("\n") + destroy_handle(lib, handle) + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for condition_shape, src1_shape, src2_shape, tensor_dtype in test_cases: + test(lib, handle, "cuda", condition_shape, src1_shape, src2_shape, tensor_dtype=tensor_dtype) + print("\n") + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + ((2, 16), (2, 16), (2, 16), torch.float32), + ((2, 3, 1, 1), (1, 4, 5), (2, 3, 4, 5), torch.float32), + ((3, 1), (3, 4), (1, 4), torch.float32), + ((1,), (3, 4), (3, 4), torch.float32), + ((2, 1, 3), (1, 4, 3), (2, 4, 1), torch.float32), + + ((2, 16), (2, 16), (2, 16), torch.float16), + ((2, 3, 1, 1), (1, 4, 5), (2, 3, 4, 5), torch.float16), + ((3, 1), (3, 4), (1, 4), torch.float16), + ((1,), (3, 4), (3, 4), torch.float16), + ((2, 1, 3), (1, 4, 3), (2, 4, 1), torch.float16), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateWhereDescriptor.restype = c_int32 + lib.infiniopCreateWhereDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopWhereDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t + ] + lib.infiniopWhere.restype = c_int32 + lib.infiniopWhere.argtypes = [ + infiniopWhereDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p + ] + lib.infiniopDestroyWhereDescriptor.restype = c_int32 + lib.infiniopDestroyWhereDescriptor.argtypes = [infiniopWhereDescriptor_t] + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + print("All tests passed!") \ No newline at end of file diff --git a/src/ops/clip/cpu/clip_cpu.cc b/src/ops/clip/cpu/clip_cpu.cc new file mode 100644 index 00000000..558f3ead --- /dev/null +++ b/src/ops/clip/cpu/clip_cpu.cc @@ -0,0 +1,80 @@ +#include "clip_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + +infiniopStatus_t cpuCreateClipDescriptor(infiniopHandle_t handle, + ClipCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + float* min, + float* max + ){ + + if (x->dt != F16 && x->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (x->ndim != y->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (x->dt != y->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (!is_contiguous(x) || 
!is_contiguous(y)) { + return STATUS_BAD_TENSOR_STRIDES; + } + uint64_t element_num = 1; + for (uint64_t i = 0; i < x->ndim; i++) { + element_num *= x->shape[i]; + } + bool has_min = min != nullptr; + bool has_max = max != nullptr; + float min_ = has_min ? *min : std::numeric_limits<float>::lowest(); + float max_ = has_max ? *max : std::numeric_limits<float>::max(); + *desc_ptr = new ClipCpuDescriptor{ + DevCpu, + x->dt, + min_, + max_, + has_min, + has_max, + element_num + }; + return STATUS_SUCCESS; +} + +template<typename Tdata> +infiniopStatus_t clip_cpu(ClipCpuDescriptor_t desc, + void const *x, + void *y){ + auto x_ = reinterpret_cast<Tdata const *>(x); + auto y_ = reinterpret_cast<Tdata *>(y); + for (uint64_t i = 0; i < desc->element_num; i++) { + if constexpr (std::is_same<Tdata, uint16_t>::value){ + float x_f = f16_to_f32(x_[i]); + x_f = desc->has_min ? std::max(x_f, desc->min) : x_f; + x_f = desc->has_max ? std::min(x_f, desc->max) : x_f; + y_[i] = f32_to_f16(x_f); + } + else{ + y_[i] = std::min(std::max(x_[i], desc->min), desc->max); + } + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuClip(ClipCpuDescriptor_t desc, + void const*x, + void *y, + void *stream){ + if (desc->dtype == F16) { + return clip_cpu<uint16_t>(desc, x, y); + } + if (desc->dtype == F32) { + return clip_cpu<float>(desc, x, y); + } + return STATUS_BAD_TENSOR_DTYPE; +} +infiniopStatus_t cpuDestroyClipDescriptor(ClipCpuDescriptor_t desc){ + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/clip/cpu/clip_cpu.h b/src/ops/clip/cpu/clip_cpu.h new file mode 100644 index 00000000..9bc23753 --- /dev/null +++ b/src/ops/clip/cpu/clip_cpu.h @@ -0,0 +1,32 @@ +#ifndef __CPU_CLIP_H__ +#define __CPU_CLIP_H__ + +#include "operators.h" +struct ClipCpuDescriptor { + Device device; + DT dtype; + float min; + float max; + bool has_min; + bool has_max; + uint64_t element_num; +}; + +typedef struct ClipCpuDescriptor *ClipCpuDescriptor_t; + +infiniopStatus_t cpuCreateClipDescriptor(infiniopHandle_t handle, + ClipCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + float* min, + float* max + ); + +infiniopStatus_t cpuClip(ClipCpuDescriptor_t desc, + void const *x, + void *y, + void *stream); + +infiniopStatus_t cpuDestroyClipDescriptor(ClipCpuDescriptor_t desc); + +#endif diff --git a/src/ops/clip/cuda/clip_cuda.cc b/src/ops/clip/cuda/clip_cuda.cc new file mode 100644 index 00000000..cd3a48c3 --- /dev/null +++ b/src/ops/clip/cuda/clip_cuda.cc @@ -0,0 +1,50 @@ +#include "clip_cuda.h" +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" + +infiniopStatus_t cudaCreateClipDescriptor(CudaHandle_t handle, + ClipCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + float *min, + float *max + ) { + if (x->dt != F16 && x->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (x->ndim != y->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (x->dt != y->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (!is_contiguous(x) || !is_contiguous(y)) { + return STATUS_BAD_TENSOR_STRIDES; + } + uint64_t element_num = 1; + for (uint64_t i = 0; i < x->ndim; i++) { + element_num *= x->shape[i]; + } + uint64_t ndim = y->ndim; + bool has_min = min != nullptr; + bool has_max = max != nullptr; + float min_ = has_min ? *min : std::numeric_limits<float>::lowest(); + float max_ = has_max ?
*max : std::numeric_limits<float>::max(); + *desc_ptr = new ClipCudaDescriptor{ + DevNvGpu, + x->dt, + ndim, + element_num, + min_, + max_, + has_min, + has_max + }; + return STATUS_SUCCESS; +} + + +infiniopStatus_t cudaDestroyClipDescriptor(ClipCudaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/clip/cuda/clip_cuda.cu b/src/ops/clip/cuda/clip_cuda.cu new file mode 100644 index 00000000..ceb8663f --- /dev/null +++ b/src/ops/clip/cuda/clip_cuda.cu @@ -0,0 +1,97 @@ +#include "../../../devices/cuda/cuda_handle.h" +#include "../../utils.h" +#include "clip_cuda.h" +#include <cuda_fp16.h> + +#define WARP_SIZE 32 +#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0]) +#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0]) + +#define LDST128BITS_CONST(value) (reinterpret_cast<const float4 *>(&(value))[0]) +#define FLOAT4_CONST(value) (reinterpret_cast<const float4 *>(&(value))[0]) + +__global__ void clip_f32x4_kernel(const float *a, float *b, float max_value, float min_value, int N){ + int idx = 4 * (blockDim.x * blockIdx.x + threadIdx.x); + if (idx < N) { + int remaining = N - idx; + float4 reg_a, reg_b; + if (remaining >= 4) { + reg_a = FLOAT4_CONST(a[idx]); + } else { + reg_a.x = a[idx]; + reg_a.y = (remaining >= 2) ? a[idx + 1] : 0; + reg_a.z = (remaining >= 3) ? a[idx + 2] : 0; + reg_a.w = 0; + } + reg_b.x = fminf(fmaxf(reg_a.x, min_value), max_value); + reg_b.y = fminf(fmaxf(reg_a.y, min_value), max_value); + reg_b.z = fminf(fmaxf(reg_a.z, min_value), max_value); + reg_b.w = fminf(fmaxf(reg_a.w, min_value), max_value); + if (remaining >= 4) { + FLOAT4(b[idx]) = reg_b; + } else { + if (remaining >= 1) b[idx] = reg_b.x; + if (remaining >= 2) b[idx + 1] = reg_b.y; + if (remaining >= 3) b[idx + 2] = reg_b.z; + } + } +} + + +__global__ void clip_f16x8_pack_kernel(const half *a, half *b, float max_value, float min_value, int N){ + int idx = 8 * (blockDim.x * blockIdx.x + threadIdx.x); + if (idx >= N) return; + const half min_half = __float2half(min_value); + const half max_half = __float2half(max_value); + half pack_a[8], pack_b[8]; + if (idx + 7 < N) { + LDST128BITS(pack_a[0]) = LDST128BITS_CONST(a[idx]); + } else { + for (int i = 0; i < 8 && (idx + i) < N; i++) { + pack_a[i] = a[idx + i]; + } + } + #pragma unroll + for (int i = 0; i < 8; i++) + { + pack_b[i] = __hlt(pack_a[i], min_half) ? min_half : pack_a[i]; + pack_b[i] = __hgt(pack_b[i], max_half) ?
max_half : pack_b[i]; + } + if (idx + 7 < N) { + LDST128BITS(b[idx]) = LDST128BITS(pack_b[0]); + } else { + for (int i = 0; i < 8 && (idx + i) < N; i++) { + b[idx + i] = pack_b[i]; + } + } +} +template<typename Tdata> +infiniopStatus_t clip_nv_gpu( + ClipCudaDescriptor_t desc, + void const *x, + void *y, + int per_thread_element, + void* stream) { + uint64_t N = desc->element_num; + dim3 block(256 / per_thread_element); + dim3 grid((N + 256 - 1) / 256); + if constexpr(std::is_same<Tdata, float>::value){ + clip_f32x4_kernel<<<grid, block, 0, (cudaStream_t) stream>>>(reinterpret_cast<const float *>(x), reinterpret_cast<float *>(y), desc->max, desc->min, N); + }else{ + clip_f16x8_pack_kernel<<<grid, block, 0, (cudaStream_t) stream>>>(reinterpret_cast<const half *>(x), reinterpret_cast<half *>(y), desc->max, desc->min, N); + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaClip(ClipCudaDescriptor_t desc, + void const *x, + void *y, + void *stream){ + if (desc->dtype == F16) { + return clip_nv_gpu<half>(desc, x, y, 8, stream); + } + if (desc->dtype == F32) { + return clip_nv_gpu<float>(desc, x, y, 4, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/clip/cuda/clip_cuda.h b/src/ops/clip/cuda/clip_cuda.h new file mode 100644 index 00000000..b56911ed --- /dev/null +++ b/src/ops/clip/cuda/clip_cuda.h @@ -0,0 +1,37 @@ +#ifndef __CUDA_CLIP_H__ +#define __CUDA_CLIP_H__ + +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include + +typedef struct ClipCudaDescriptor { + Device device; + DT dtype; + uint64_t ndim; + uint64_t element_num; + float min; + float max; + bool has_min; + bool has_max; +} ClipCudaDescriptor; + +typedef struct ClipCudaDescriptor *ClipCudaDescriptor_t; + +infiniopStatus_t cudaCreateClipDescriptor(CudaHandle_t handle, + ClipCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + float* min, + float* max + ); + + +infiniopStatus_t cudaClip(ClipCudaDescriptor_t desc, + void const *x, + void *y, + void *stream); + +infiniopStatus_t cudaDestroyClipDescriptor(ClipCudaDescriptor_t desc); + +#endif// __CUDA_CLIP_H__ diff --git a/src/ops/clip/operator.cc b/src/ops/clip/operator.cc new file mode 100644 index 00000000..271b9b96 --- /dev/null +++ b/src/ops/clip/operator.cc @@ -0,0 +1,63 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/clip/clip.h" + +#ifdef ENABLE_CPU +#include "cpu/clip_cpu.h" +#endif + +#ifdef ENABLE_NV_GPU +#include "cuda/clip_cuda.h" +#endif + + +__C infiniopStatus_t infiniopCreateClipDescriptor( + infiniopHandle_t handle, + infiniopClipDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + float* min, + float* max + ) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateClipDescriptor(handle, (ClipCpuDescriptor_t *) desc_ptr, x, y, min, max); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateClipDescriptor((CudaHandle_t) handle, (ClipCudaDescriptor_t *) desc_ptr, x, y, min, max); + } +#endif + } + return STATUS_BAD_DEVICE; +} + + +__C infiniopStatus_t infiniopClip(infiniopClipDescriptor_t desc, void const *x, void *y, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuClip((ClipCpuDescriptor_t) desc, x, y, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: + return cudaClip((ClipCudaDescriptor_t) desc, x, y, stream); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc){ + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyClipDescriptor((ClipCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case
DevNvGpu: + return cudaDestroyClipDescriptor((ClipCudaDescriptor_t) desc); +#endif + } + return STATUS_BAD_DEVICE; +} \ No newline at end of file diff --git a/src/ops/gather/cpu/gather_cpu.cc b/src/ops/gather/cpu/gather_cpu.cc new file mode 100644 index 00000000..51264094 --- /dev/null +++ b/src/ops/gather/cpu/gather_cpu.cc @@ -0,0 +1,121 @@ +#include "gather_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + +infiniopStatus_t cpuCreateGatherDescriptor(infiniopHandle_t handle, + GatherCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t indices, + int64_t axis + ){ + if (y->dt != x->dt){ + return STATUS_BAD_TENSOR_DTYPE; + } + if (y->dt != F16 && y->dt != F32){ + return STATUS_BAD_TENSOR_DTYPE; + } + if (!is_contiguous(y) || !is_contiguous(x)){ + return STATUS_BAD_TENSOR_STRIDES; + } + if (axis < 0 || axis >= x->ndim){ + return STATUS_BAD_PARAM; + } + uint64_t *dst_shape = new uint64_t[y->ndim]; + uint64_t *src_shape = new uint64_t[x->ndim]; + uint64_t *indices_shape = new uint64_t[indices->ndim]; + + memcpy(dst_shape, y->shape, y->ndim * sizeof(uint64_t)); + memcpy(indices_shape, indices->shape, indices->ndim * sizeof(uint64_t)); + memcpy(src_shape, x->shape, x->ndim * sizeof(uint64_t)); + + *desc_ptr = new GatherCpuDescriptor{ + DevCpu, + y->dt, + indices->dt, + dst_shape, + src_shape, + indices_shape, + indices->ndim, + x->ndim, + y->ndim, + axis + }; + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuDestroyGatherDescriptor(GatherCpuDescriptor_t desc){ + delete[] desc->dst_shape; + delete[] desc->src_shape; + delete[] desc->indices_shape; + delete desc; + return STATUS_SUCCESS; +} + +template<typename Tdata, typename Tind> +infiniopStatus_t gather_cpu(GatherCpuDescriptor_t desc, + void const *x, + void const *indices, + void *y) +{ + auto *src_data = reinterpret_cast<Tdata const *>(x); + auto *indices_data = reinterpret_cast<Tind const *>(indices); + auto *dst_data = reinterpret_cast<Tdata *>(y); + uint64_t indices_element_count = 1; + for (uint64_t i = 0; i < desc->indices_ndim; ++i) { + indices_element_count *= desc->indices_shape[i]; + } + + const uint64_t axis = desc->axis; + uint64_t src_outer_dim = 1; + for (uint64_t i = 0; i < axis; ++i) { + src_outer_dim *= desc->src_shape[i]; + } + + uint64_t src_inner_dim = 1; + for (uint64_t i = axis + 1; i < desc->src_ndim; ++i) { + src_inner_dim *= desc->src_shape[i]; + } + for (uint64_t outer = 0; outer < src_outer_dim; ++outer) { + for (uint64_t idx = 0; idx < indices_element_count; ++idx) { + const int64_t index_val = indices_data[idx]; + const uint64_t src_offset = + outer * desc->src_shape[axis] * src_inner_dim + + index_val * src_inner_dim; + const uint64_t dst_offset = + outer * indices_element_count * src_inner_dim + + idx * src_inner_dim; + memcpy( + dst_data + dst_offset, + src_data + src_offset, + sizeof(Tdata) * src_inner_dim + ); + } + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuGather(GatherCpuDescriptor_t desc, + void const *x, + void const *indices, + void *y, + void *stream){ + if (desc->dtype == F16){ + if (desc->indices_dtype == I32){ + return gather_cpu<uint16_t, int32_t>(desc, x, indices, y); + } + else if (desc->indices_dtype == I64){ + return gather_cpu<uint16_t, int64_t>(desc, x, indices, y); + } + } + if (desc->dtype == F32){ + if (desc->indices_dtype == I32){ + return gather_cpu<float, int32_t>(desc, x, indices, y); + } + else if (desc->indices_dtype == I64){ + return gather_cpu<float, int64_t>(desc, x, indices, y); + } + } + return STATUS_BAD_TENSOR_DTYPE; +} + diff --git a/src/ops/gather/cpu/gather_cpu.h
b/src/ops/gather/cpu/gather_cpu.h new file mode 100644 index 00000000..41243cae --- /dev/null +++ b/src/ops/gather/cpu/gather_cpu.h @@ -0,0 +1,36 @@ +#ifndef __CPU_GATHER_H__ +#define __CPU_GATHER_H__ + +#include "operators.h" +struct GatherCpuDescriptor { + Device device; + DT dtype; + DT indices_dtype; + uint64_t const *dst_shape; + uint64_t const *src_shape; + uint64_t const *indices_shape; + uint64_t indices_ndim; + uint64_t src_ndim; + uint64_t dst_ndim; + int64_t axis; +}; + +typedef struct GatherCpuDescriptor *GatherCpuDescriptor_t; + +infiniopStatus_t cpuCreateGatherDescriptor(infiniopHandle_t handle, + GatherCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t indices, + int64_t axis + ); + +infiniopStatus_t cpuGather(GatherCpuDescriptor_t desc, + void const *data, + void const *indices, + void *dst, + void *stream); + +infiniopStatus_t cpuDestroyGatherDescriptor(GatherCpuDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/src/ops/gather/cuda/gather_cuda.cc b/src/ops/gather/cuda/gather_cuda.cc new file mode 100644 index 00000000..a5f903af --- /dev/null +++ b/src/ops/gather/cuda/gather_cuda.cc @@ -0,0 +1,57 @@ +#include "gather_cuda.h" +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" + +infiniopStatus_t cudaCreateGatherDescriptor(CudaHandle_t handle, + GatherCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t indices, + int64_t axis + ){ + if (y->dt != x->dt){ + return STATUS_BAD_TENSOR_DTYPE; + } + if (y->dt != F16 && y->dt != F32){ + return STATUS_BAD_TENSOR_DTYPE; + } + if (!is_contiguous(y) || !is_contiguous(x)){ + return STATUS_BAD_TENSOR_STRIDES; + } + if (axis < 0 || axis >= x->ndim){ + return STATUS_BAD_PARAM; + } + int otherDims = 1; + for (int i = 0; i < axis; i++){ + otherDims *= static_cast(x->shape[i]); + } + for (int i = axis + 1; i < x->ndim; i++){ + otherDims *= static_cast(x->shape[i]); + } + int stride = 1; + for (int i = axis + 1; i < x->ndim; i++){ + stride *= static_cast(x->shape[i]); + } + int indices_size = 1; + for (int i = 0; i < indices->ndim; i++){ + indices_size *= static_cast(indices->shape[i]); + } + int dim_size = static_cast(x->shape[axis]); + *desc_ptr = new GatherCudaDescriptor{ + DevNvGpu, + x->dt, + indices->dt, + axis, + otherDims, + x->ndim, + y->ndim, + dim_size, + indices_size, + stride + }; + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroyGatherDescriptor(GatherCudaDescriptor_t desc){ + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/gather/cuda/gather_cuda.cu b/src/ops/gather/cuda/gather_cuda.cu new file mode 100644 index 00000000..9f794a51 --- /dev/null +++ b/src/ops/gather/cuda/gather_cuda.cu @@ -0,0 +1,183 @@ +#include "../../../devices/cuda/cuda_handle.h" +#include "../../utils.h" +#include "gather_cuda.h" +#include + +#define LDST128BITS_CONST(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) +template +__global__ void gather_kernel( + const T* input, + const IndexType* indices, + T* output, + int otherDims, + int dim_size, + int stride, + int indices_size +) { + extern __shared__ char shared_mem[]; + IndexType* sharedIndices = (IndexType*)shared_mem; + if constexpr(use_shared) { + for (int i = threadIdx.x; i < indices_size; i += blockDim.x) { + if (i < indices_size) { + sharedIndices[i] = indices[i]; + } + } + __syncthreads(); + } + constexpr int ITEMS_PER_THREAD = 4; + int 
tid_base = (blockIdx.x * blockDim.x + threadIdx.x) * ITEMS_PER_THREAD; + int total_elements = otherDims * indices_size; + int stride_indices = stride * indices_size; + + #pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + int tid = tid_base + i; + if (tid >= total_elements) break; + int outer_idx = tid / stride_indices; + int remaining = tid - outer_idx * stride_indices; + int indices_idx = remaining / stride; + int inner_idx = remaining - indices_idx * stride; + IndexType gather_idx; + if constexpr(use_shared) { + gather_idx = indices_size == 1 ? sharedIndices[0] : sharedIndices[indices_idx]; + } else { + gather_idx = indices_size == 1 ? indices[0] : indices[indices_idx]; + } + if (gather_idx >= 0 && gather_idx < dim_size) { + int outer_offset = outer_idx * stride_indices; + int indices_offset = indices_idx * stride; + int input_offset = outer_idx * stride * dim_size + gather_idx * stride; + output[outer_offset + indices_offset + inner_idx] = input[input_offset + inner_idx]; + } + } +} +template +__global__ void gather_vectorized_kernel( + const T* input, + const IndexType* indices, + T* output, + int otherDims, + int dim_size, + int stride, + int indices_size +) { + extern __shared__ char shared_mem[]; + IndexType* sharedIndices = (IndexType*)shared_mem; + + if constexpr(use_shared) { + for (int i = threadIdx.x; i < indices_size; i += blockDim.x) { + if (i < indices_size) { + sharedIndices[i] = indices[i]; + } + } + __syncthreads(); + } + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride_in_vec4; + int total_vectors; + if constexpr(std::is_same::value){ + stride_in_vec4 = stride / 4; + total_vectors = otherDims * indices_size / 4; + } else { + stride_in_vec4 = stride / 8; + total_vectors = otherDims * indices_size / 8; + } + if (tid < total_vectors) { + int outer_idx = tid / (stride_in_vec4 * indices_size); + int remaining = tid - (outer_idx * stride_in_vec4 * indices_size); + int indices_idx = remaining / stride_in_vec4; + int inner_idx = remaining - (indices_idx * stride_in_vec4); + IndexType gather_idx; + if constexpr(use_shared) { + gather_idx = indices_size == 1 ? sharedIndices[0] : sharedIndices[indices_idx]; + } else { + gather_idx = indices_size == 1 ? indices[0] : indices[indices_idx]; + } + if (gather_idx >= 0 && gather_idx < dim_size) { + int input_idx = (outer_idx * stride_in_vec4 * dim_size + gather_idx * stride_in_vec4 + inner_idx) * 4; + int output_idx = (outer_idx * stride_in_vec4 * indices_size + indices_idx * stride_in_vec4 + inner_idx) * 4; + LDST128BITS(output[output_idx]) = LDST128BITS_CONST(input[input_idx]); + } + } +} + +template +infiniopStatus_t gather_nv_gpu(GatherCudaDescriptor_t desc, + void const *x, + void *y, + void const *indices, + void *stream){ + + const size_t sharedMemSize = desc->indices_size <= 1024 ? 
desc->indices_size * sizeof(IndexType) : 0; + if (desc->stride % 4 == 0 && desc->stride >= 4){ + const int block_size = 128; + const int total_vectors = desc->otherDims * desc->indices_size / 4; + const int gridSize = (total_vectors + block_size - 1) / block_size; + if (sharedMemSize == 0) gather_vectorized_kernel<<>>( + reinterpret_cast(x), + reinterpret_cast(indices), + reinterpret_cast(y), + desc->otherDims, + desc->dim_size, + desc->stride, + desc->indices_size + ); + else gather_vectorized_kernel<<>>( + reinterpret_cast(x), + reinterpret_cast(indices), + reinterpret_cast(y), + desc->otherDims, + desc->dim_size, + desc->stride, + desc->indices_size + ); + } else { + int block_size = 128; + const int total_elements = desc->otherDims * desc->indices_size; + const int gridSize = (total_elements + block_size * 4 - 1) / (block_size * 4); + if (sharedMemSize == 0) gather_kernel<<>>( + reinterpret_cast(x), + reinterpret_cast(indices), + reinterpret_cast(y), + desc->otherDims, + desc->dim_size, + desc->stride, + desc->indices_size + ); + else gather_kernel<<>>( + reinterpret_cast(x), + reinterpret_cast(indices), + reinterpret_cast(y), + desc->otherDims, + desc->dim_size, + desc->stride, + desc->indices_size + ); + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaGather(GatherCudaDescriptor_t desc, + const void *x, + const void *indices, + void *y, + void *stream){ + if (desc->dtype == F32){ + if (desc->indices_dtype == I32){ + return gather_nv_gpu(desc, x, y, indices, stream); + } + else if (desc->indices_dtype == I64){ + return gather_nv_gpu(desc, x, y, indices, stream); + } + } + if (desc->dtype == F16){ + if (desc->indices_dtype == I32){ + return gather_nv_gpu(desc, x, y, indices, stream); + } + else if (desc->indices_dtype == I64){ + return gather_nv_gpu(desc, x, y, indices, stream); + } + } + return STATUS_BAD_TENSOR_DTYPE; +} \ No newline at end of file diff --git a/src/ops/gather/cuda/gather_cuda.h b/src/ops/gather/cuda/gather_cuda.h new file mode 100644 index 00000000..562c5bbc --- /dev/null +++ b/src/ops/gather/cuda/gather_cuda.h @@ -0,0 +1,41 @@ +#ifndef __CUDA_GATHER_H__ +#define __CUDA_GATHER_H__ + +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include + +typedef struct GatherCudaDescriptor { + Device device; + DT dtype; + DT indices_dtype; + int64_t axis; + int otherDims; + uint64_t input_ndim; + uint64_t output_ndim; + int dim_size; + int indices_size; + int stride; + +} GatherCudaDescriptor; + +typedef struct GatherCudaDescriptor *GatherCudaDescriptor_t; + +infiniopStatus_t cudaCreateGatherDescriptor(CudaHandle_t handle, + GatherCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t indices, + int64_t axis + ); + + +infiniopStatus_t cudaGather(GatherCudaDescriptor_t desc, + void const *x, + void const *indices, + void *y, + void *stream); + +infiniopStatus_t cudaDestroyGatherDescriptor(GatherCudaDescriptor_t desc); + +#endif// __CUDA_MATMUL_H__ diff --git a/src/ops/gather/operator.cc b/src/ops/gather/operator.cc new file mode 100644 index 00000000..8d51c43c --- /dev/null +++ b/src/ops/gather/operator.cc @@ -0,0 +1,60 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/gather/gather.h" + +#ifdef ENABLE_CPU +#include "cpu/gather_cpu.h" +#endif + +#ifdef ENABLE_NV_GPU +#include "cuda/gather_cuda.h" +#endif + +__C infiniopStatus_t infiniopCreateGatherDescriptor( + infiniopHandle_t handle, + infiniopGatherDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + 
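+ // y is expected to follow ONNX Gather semantics: x's shape with dimension `axis` replaced by the shape of `indices` (assumed here, not validated by either backend)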
infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t indices, + int64_t axis + ) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateGatherDescriptor(handle, (GatherCpuDescriptor_t *) desc_ptr, y, x, indices, axis); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateGatherDescriptor((CudaHandle_t) handle, (GatherCudaDescriptor_t *) desc_ptr, y, x, indices, axis); + } +#endif + } + return STATUS_BAD_DEVICE; +} +__C infiniopStatus_t infiniopGather(infiniopGatherDescriptor_t desc, void const *x, void const *indices, void *y, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuGather((GatherCpuDescriptor_t) desc, x, indices, y, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: + return cudaGather((GatherCudaDescriptor_t) desc, x, indices, y, stream); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc){ + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyGatherDescriptor((GatherCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: + return cudaDestroyGatherDescriptor((GatherCudaDescriptor_t) desc); +#endif + } + return STATUS_BAD_DEVICE; +} \ No newline at end of file diff --git a/src/ops/reduce/cpu/reduce_cpu.cc b/src/ops/reduce/cpu/reduce_cpu.cc new file mode 100644 index 00000000..74ad723c --- /dev/null +++ b/src/ops/reduce/cpu/reduce_cpu.cc @@ -0,0 +1,304 @@ +#include "reduce_cpu.h" +#include "../../utils.h" +#include +infiniopStatus_t cpuCreateReduceDescriptor(infiniopHandle_t handle, + ReduceCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + int reduce_type, + bool noop_with_empty_axes, + bool keepdims) { + uint64_t ndim = y->ndim; + uint64_t x_ndim = x->ndim; + if (reduce_type > 2){ + return STATUS_BAD_PARAM; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (y->dt != F16 && y->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (!is_contiguous(y) || !is_contiguous(x)) { + return STATUS_BAD_TENSOR_STRIDES; + } + uint64_t *x_shape = new uint64_t[x_ndim]; + uint64_t *y_shape = new uint64_t[ndim]; + int64_t *x_strides = new int64_t[x_ndim]; + int64_t *y_strides = new int64_t[ndim]; + uint64_t y_size = 1; + for (uint64_t i = 0; i < ndim; i++) { + y_size *= y->shape[i]; + } + memcpy(y_shape, y->shape, ndim * sizeof(uint64_t)); + memcpy(y_strides, y->strides, ndim * sizeof(int64_t)); + memcpy(x_strides, x->strides, x_ndim * sizeof(int64_t)); + memcpy(x_shape, x->shape, x_ndim * sizeof(uint64_t)); + + if (axes != nullptr && n > 0) { + bool is_axes_static = true; + for (uint64_t i = 0; i < n; i++) { + if (axes[i] >= x->ndim) { + return STATUS_BAD_PARAM; // axis index out of range + } + } + std::unordered_set<int64_t> axes_set; + for (uint64_t i = 0; i < n; i++) { + if (axes_set.find(axes[i]) != axes_set.end()) { + return STATUS_BAD_PARAM; // duplicate axis + } + axes_set.insert(axes[i]); + } + std::vector<int64_t> axes_vec(axes_set.begin(), axes_set.end()); + std::sort(axes_vec.begin(), axes_vec.end()); + int64_t* unique_axes = new int64_t[axes_vec.size()]; + std::copy(axes_vec.begin(), axes_vec.end(), unique_axes); + uint64_t *reduce_axes_stride = new uint64_t[axes_vec.size()]; + for (int i = axes_vec.size() - 1; i >= 0; i--){ + reduce_axes_stride[i] = 1; + } + for (int i = axes_vec.size() - 2; i >= 0; i--){ + reduce_axes_stride[i] = reduce_axes_stride[i + 1] * x->shape[axes_vec[i + 1]]; + } + int reduce_element_num = 1; + 
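+ // reduce_element_num is the number of input elements folded into each output element; the mean path (reduce_type 0) divides its accumulated sum by this count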
for (auto axis: axes_vec) { + reduce_element_num *= x->shape[axis]; + } + *desc_ptr = new ReduceCpuDescriptor{ + DevCpu, + y->dt, + ndim, + x_ndim, + x_shape, + unique_axes, + x_strides, + reduce_axes_stride, + y_strides, + y_size, + reduce_type, + reduce_element_num, + n, + is_axes_static, + noop_with_empty_axes, + true, + keepdims + }; + } + else { + *desc_ptr = new ReduceCpuDescriptor{ + DevCpu, + y->dt, + ndim, + x_ndim, + x_shape, + nullptr, + x_strides, + nullptr, + y_strides, + y_size, + reduce_type, + 1, + 0, + false, + noop_with_empty_axes, + false, + keepdims + }; + } + return STATUS_SUCCESS; +} + + +infiniopStatus_t cpuDestroyReduceDescriptor(ReduceCpuDescriptor_t desc) { + delete[] desc->x_shape; + delete[] desc->axes; + delete[] desc->x_strides; + delete[] desc->reduce_axes_stride; + delete desc; + return STATUS_SUCCESS; +} +template +infiniopStatus_t reduce_cpu(ReduceCpuDescriptor_t desc, + void *y, + void const *x){ + auto y_ = reinterpret_cast(y); + auto x_ = reinterpret_cast(x); + auto input_strides = desc->x_strides; + auto reduce_axes_stride = desc->reduce_axes_stride; + auto num_axes = desc->axes_num; + auto axes = desc->axes; + auto ndim = desc->ndim; + auto x_shape = desc->x_shape; + auto y_strides = desc->y_strides; + int i, j; + int indices[desc->x_ndim]; + std::vector non_reduce_axes; + for (int j = 0; j < desc->x_ndim; ++j) { + if (!std::binary_search(axes, axes + num_axes, j)) { + non_reduce_axes.push_back(j); + } + } + for (i = 0; i < desc->y_size; i++) { + float sum_value = 0; + std::vector indices(desc->x_ndim, 0); + int global_index = i; + if (!non_reduce_axes.empty()) { + if constexpr (keepdims){ + for (int j = 0; j < non_reduce_axes.size(); j++) { + indices[non_reduce_axes[j]] = global_index / y_strides[non_reduce_axes[j]]; + global_index %= y_strides[non_reduce_axes[j]]; + } + }else{ + for (int j = 0; j < non_reduce_axes.size(); j++) { + indices[non_reduce_axes[j]] = global_index / y_strides[j]; + global_index %= y_strides[j]; + } + } + } + int64_t base_offset = 0; + for (int j = 0; j < desc->x_ndim; ++j) { + base_offset += indices[j] * desc->x_strides[j]; + } + for (j = 0; j < desc->reduce_element_num; j++){ + int64_t offset = base_offset; + uint64_t remaining = j; + int k; + for (k = 0; k < num_axes; k++){ + const int axis = axes[k]; + const uint64_t coord = remaining / reduce_axes_stride[k]; + offset += coord * input_strides[axis]; + remaining %= reduce_axes_stride[k]; + } + switch(desc->reduce_mode){ + case 0: + if constexpr (std::is_same::value){ + sum_value += f16_to_f32(x_[offset]); + } + else{ + sum_value += x_[offset]; + } + break; + case 1: + if constexpr (std::is_same::value){ + y_[i] = f32_to_f16(std::fmax(f16_to_f32(y_[i]), f16_to_f32(x_[offset]))); + } + else{ + y_[i] = std::max(y_[i], x_[offset]); + } + break; + case 2: + if constexpr (std::is_same::value){ + y_[i] = f32_to_f16(std::fmin(f16_to_f32(y_[i]), f16_to_f32(x_[offset]))); + } + else{ + y_[i] = std::min(y_[i], x_[offset]); + } + break; + } + } + if (desc->reduce_mode == 0){ + sum_value /= desc->reduce_element_num; + if constexpr (std::is_same::value){ + y_[i] = f32_to_f16(sum_value); + } + else{ + y_[i] = sum_value; + } + } + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuReduce(ReduceCpuDescriptor_t desc, + void *y, + void const *x, + void *dynamic_axes, + uint64_t dynamic_axes_size, + void *stream){ + if (desc->is_axes_static == true && dynamic_axes_size > 0){ + return STATUS_BAD_PARAM; + } + if (desc->is_axes_static == false && dynamic_axes_size == 0){ + if 
(desc->noop_with_empty_axes){ + if (desc->dt == F16){ + memcpy(y, x, desc->y_size * sizeof(uint16_t)); + return STATUS_SUCCESS; + } + else if (desc->dt == F32){ + memcpy(y, x, desc->y_size * sizeof(float)); + return STATUS_SUCCESS; + } + else return STATUS_BAD_TENSOR_DTYPE; + } + delete[] desc->axes; + delete[] desc->reduce_axes_stride; + desc->axes = new int64_t[desc->x_ndim]; + desc->reduce_axes_stride = new uint64_t[desc->x_ndim]; + std::vector full_axes(desc->x_ndim); + std::iota(full_axes.begin(), full_axes.end(), 0); + std::copy(full_axes.begin(), full_axes.end(), desc->axes); + for (int i = desc->x_ndim - 1; i >= 0; i--){ + desc->reduce_axes_stride[i] = 1; + } + for (int i = desc->x_ndim - 2; i >= 0; i--){ + desc->reduce_axes_stride[i] = desc->reduce_axes_stride[i + 1] * desc->x_shape[full_axes[i + 1]]; + } + + desc->axes_num = desc->x_ndim; + desc->reduce_element_num = 1; + for (int i = 0; i < desc->x_ndim; i++){ + desc->reduce_element_num *= desc->x_shape[i]; + } + desc->owns_axes_memory = true; + } + if (desc->is_axes_static == false && dynamic_axes_size > 0){ + auto dynamic_axes_data = reinterpret_cast(dynamic_axes); + + for (uint64_t i = 0; i < dynamic_axes_size; i++){ + if (dynamic_axes_data[i] >= desc->x_ndim){ + return STATUS_BAD_PARAM; + } + } + std::unordered_set axes_set; + for (uint64_t i = 0; i < dynamic_axes_size; i++){ + if (axes_set.find(dynamic_axes_data[i]) != axes_set.end()){ + return STATUS_BAD_PARAM; + } + axes_set.insert(dynamic_axes_data[i]); + } + std::vector axes_vec(axes_set.begin(), axes_set.end()); + std::sort(axes_vec.begin(), axes_vec.end()); + + delete[] desc->axes; + delete[] desc->reduce_axes_stride; + desc->axes = new int64_t[axes_vec.size()]; + desc->reduce_axes_stride = new uint64_t[axes_vec.size()]; + desc->axes_num = axes_vec.size(); + std::copy(axes_vec.begin(), axes_vec.end(), desc->axes); + for (int i = axes_vec.size() - 1; i >= 0; i--){ + desc->reduce_axes_stride[i] = 1; + } + for (int i = axes_vec.size() - 2; i >= 0; i--){ + desc->reduce_axes_stride[i] = desc->reduce_axes_stride[i + 1] * desc->x_shape[axes_vec[i + 1]]; + } + for (auto axis: axes_vec){ + desc->reduce_element_num *= desc->x_shape[axis]; + } + desc->owns_axes_memory = true; + } + if (desc->dt == F16) { + if (desc->keepdims == true) { + return reduce_cpu(desc, y, x); + } + return reduce_cpu(desc, y, x); + } + if (desc->dt == F32) { + if (desc->keepdims == true) { + return reduce_cpu(desc, y, x); + } + return reduce_cpu(desc, y, x); + } + return STATUS_BAD_TENSOR_DTYPE; +} \ No newline at end of file diff --git a/src/ops/reduce/cpu/reduce_cpu.h b/src/ops/reduce/cpu/reduce_cpu.h new file mode 100644 index 00000000..9b79dec0 --- /dev/null +++ b/src/ops/reduce/cpu/reduce_cpu.h @@ -0,0 +1,54 @@ +#ifndef __CPU_REDUCE_H__ +#define __CPU_REDUCE_H__ + +#include "../../../devices/cpu/common_cpu.h" +#include "operators.h" +#include +#include +#include +#include +#include + +struct ReduceCpuDescriptor { + Device device; + DataLayout dt; + uint64_t ndim; + uint64_t x_ndim; + uint64_t const *x_shape; + int64_t *axes; + int64_t const *x_strides; + uint64_t *reduce_axes_stride; + int64_t const *y_strides; + uint64_t y_size; + int reduce_mode; + int reduce_element_num; + uint64_t axes_num; + bool is_axes_static; + bool noop_with_empty_axes; + bool owns_axes_memory; + bool keepdims; +}; + +typedef struct ReduceCpuDescriptor *ReduceCpuDescriptor_t; + +infiniopStatus_t cpuCreateReduceDescriptor(infiniopHandle_t handle, + ReduceCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + 
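+ // axes == nullptr or n == 0 marks the axes as dynamic; they are then supplied (if at all) through the dynamic_axes argument of cpuReduce at execution time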
infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + int reduce_type, + bool noop_with_empty_axes, + bool keepdims); + + +infiniopStatus_t cpuReduce(ReduceCpuDescriptor_t desc, + void *y, + void const *x, + void *dynamic_axes, + uint64_t dynamic_axes_size, + void *stream); + +infiniopStatus_t cpuDestroyReduceDescriptor(ReduceCpuDescriptor_t desc); + +#endif diff --git a/src/ops/reduce/cuda/reduce_cuda.cc b/src/ops/reduce/cuda/reduce_cuda.cc new file mode 100644 index 00000000..60e057ed --- /dev/null +++ b/src/ops/reduce/cuda/reduce_cuda.cc @@ -0,0 +1,156 @@ +#include "reduce_cuda.h" +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +// need reduce_size, output_size, output_stride, input_stride + +infiniopStatus_t cudaCreateReduceDescriptor(CudaHandle_t handle, + ReduceCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t axes_size, + int reduce_op_type, + bool keepdims + ) { + if (x->dt != F16 && x->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (keepdims) { + if (x->ndim != y->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (x->dt != y->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (!is_contiguous(x) || !is_contiguous(y)) { + return STATUS_BAD_TENSOR_STRIDES; + } + uint64_t element_num = 1; + uint64_t output_size = 1; + for (uint64_t i = 0; i < x->ndim; i++) { + element_num *= x->shape[i]; + } + + for (uint64_t i = 0; i < y->ndim; i++) { + output_size *= y->shape[i]; + } + uint64_t reduce_size = element_num / output_size; + + int64_t *input_strides = new int64_t[x->ndim]; + int64_t *output_strides = new int64_t[y->ndim]; + uint64_t *input_shape = new uint64_t[x->ndim]; + uint64_t *output_shape = new uint64_t[y->ndim]; + int64_t *h_axes = new int64_t[axes_size]; + uint64_t* reduce_axes_stride = new uint64_t[axes_size]; + std::fill_n(reduce_axes_stride, axes_size, 1); + for (int i = axes_size - 2; i >= 0; i--){ + reduce_axes_stride[i] = reduce_axes_stride[i + 1] * x->shape[axes[i + 1]]; + } + memcpy(h_axes, axes, axes_size * sizeof(int64_t)); + memcpy(input_shape, x->shape, x->ndim * sizeof(uint64_t)); + memcpy(output_shape, y->shape, y->ndim * sizeof(uint64_t)); + memcpy(input_strides, x->strides, x->ndim * sizeof(int64_t)); + memcpy(output_strides, y->strides, y->ndim * sizeof(int64_t)); + int prefix_size = 1, suffix_size = 1; + bool if_reduce_axes_contiguous = true; + int reduce_mode = 0; + std::vector non_reduce_axes; + for (int j = 0; j < x->ndim; ++j) { + if (!std::binary_search(axes, axes + axes_size, j)) { + non_reduce_axes.push_back(j); + } + } + int64_t *h_non_reduce_axes = new int64_t[non_reduce_axes.size()]; + std::copy(non_reduce_axes.begin(), non_reduce_axes.end(), h_non_reduce_axes); + + for (uint64_t i = 0; i < axes_size; i++) { + if (i < axes_size - 1 && axes[i] != axes[i + 1] - 1) { + if_reduce_axes_contiguous = false; + } + } + if (if_reduce_axes_contiguous) { + if (axes_size == x->ndim) { + reduce_mode = 0; + } else { + for (uint64_t i = 0; i < axes[0]; i++) { + prefix_size *= x->shape[i]; + } + for (uint64_t i = axes[axes_size - 1] + 1; i < x->ndim; i++) { + suffix_size *= x->shape[i]; + } + reduce_mode = 1; + } + } else { + reduce_mode = 2; + } + int64_t *d_non_reduce_axes; + int64_t *d_input_strides; + int64_t *d_output_strides; + int64_t *d_reduce_axes; + uint64_t *d_reduce_axes_stride; + uint64_t *d_input_shape; + uint64_t *d_output_shape; + + checkCudaErrorWithCode(cudaMalloc((void**)&d_reduce_axes_stride, axes_size * sizeof(uint64_t)), 
STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMalloc((void**)&d_reduce_axes, axes_size * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMalloc((void**)&d_non_reduce_axes, non_reduce_axes.size() * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMalloc((void**)&d_input_strides, x->ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMalloc((void**)&d_output_strides, y->ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMalloc((void**)&d_input_shape, x->ndim * sizeof(uint64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkCudaErrorWithCode(cudaMalloc((void**)&d_output_shape, y->ndim * sizeof(uint64_t)), STATUS_MEMORY_NOT_ALLOCATED); + + checkCudaErrorWithCode(cudaMemcpy(d_reduce_axes_stride, reduce_axes_stride, axes_size * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(d_reduce_axes, h_axes, axes_size * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(d_non_reduce_axes, h_non_reduce_axes, non_reduce_axes.size() * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(d_input_strides, input_strides, x->ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(d_output_strides, output_strides, y->ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(d_input_shape, input_shape, x->ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaMemcpy(d_output_shape, output_shape, y->ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + + *desc_ptr = new ReduceCudaDescriptor{ + DevNvGpu, + x->dt, + x->ndim, + y->ndim, + d_non_reduce_axes, + d_input_strides, + d_output_strides, + d_input_shape, + d_output_shape, + d_reduce_axes, + d_reduce_axes_stride, + reduce_size, + element_num, + output_size, + static_cast(reduce_op_type), + reduce_mode, + axes_size, + keepdims, + axes[0], + axes[axes_size - 1], + prefix_size, + suffix_size + }; + delete [] h_axes; + delete [] reduce_axes_stride; + delete [] h_non_reduce_axes; + delete [] input_strides; + delete [] output_strides; + delete [] input_shape; + delete [] output_shape; + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroyReduceDescriptor(ReduceCudaDescriptor_t desc) { + + checkCudaErrorWithCode(cudaFree((void*)desc->reduce_axes), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void*)desc->reduce_axes_stride), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void*)desc->non_reduce_axes), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void*)desc->input_strides), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void*)desc->output_strides), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void*)desc->input_shape), STATUS_EXECUTION_FAILED); + checkCudaErrorWithCode(cudaFree((void*)desc->output_shape), STATUS_EXECUTION_FAILED); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/reduce/cuda/reduce_cuda.cu b/src/ops/reduce/cuda/reduce_cuda.cu new file mode 100644 index 00000000..b603a13f --- /dev/null +++ b/src/ops/reduce/cuda/reduce_cuda.cu @@ -0,0 +1,625 @@ +#include "../../../devices/cuda/cuda_handle.h" +#include "../../utils.h" +#include "reduce_cuda.h" +#include +#include +#include +#define WARP_SIZE 32 + +#define 
FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) +#define FLOAT4_CONST(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS_CONST(value) (reinterpret_cast(&(value))[0]) + +enum class ReduceOp { + SUM, + MIN, + MAX, + MEAN +}; + +__global__ void divide_by_n_kernel(float* y, int N) { + y[0] /= static_cast(N); +} + +__global__ void divide_by_n_kernel(half* y, float* temp_buffer, int N) { + y[0] = __float2half(temp_buffer[0] / static_cast(N)); +} + +template +__device__ __forceinline__ T init_value() { + if constexpr (std::is_same_v) { + if constexpr (Op == ReduceOp::SUM || Op == ReduceOp::MEAN) + return 0.0f; + else if constexpr (Op == ReduceOp::MAX) + return -INFINITY; + else if constexpr (Op == ReduceOp::MIN) + return INFINITY; + } else if constexpr (std::is_same_v) { + if constexpr (Op == ReduceOp::SUM || Op == ReduceOp::MEAN) + return __float2half(0.0f); + else if constexpr (Op == ReduceOp::MAX) + return __float2half(-INFINITY); + else if constexpr (Op == ReduceOp::MIN) + return __float2half(INFINITY); + } +} + +template +__device__ __forceinline__ float reduce_op(float a, float b) { + if constexpr (Op == ReduceOp::SUM || Op == ReduceOp::MEAN) + return a + b; + else if constexpr (Op == ReduceOp::MAX) + return fmaxf(a, b); + else if constexpr (Op == ReduceOp::MIN) + return fminf(a, b); +} + +template +__device__ __forceinline__ float warp_reduce(float val) { + #pragma unroll + for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) { + if constexpr (Op == ReduceOp::SUM || Op == ReduceOp::MEAN) + val += __shfl_down_sync(0xffffffff, val, mask); + else if constexpr (Op == ReduceOp::MAX) + val = fmaxf(val, __shfl_down_sync(0xffffffff, val, mask)); + else if constexpr (Op == ReduceOp::MIN) + val = fminf(val, __shfl_down_sync(0xffffffff, val, mask)); + } + return val; +} + +template +__device__ __forceinline__ float finalize_result(float val, int count) { + if constexpr (Op == ReduceOp::MEAN) + return val / static_cast(count); + else + return val; +} + +template +__device__ __forceinline__ half reduce_op(half a, half b) { + if constexpr (Op == ReduceOp::SUM || Op == ReduceOp::MEAN) + return __hadd(a, b); + else if constexpr (Op == ReduceOp::MAX) + return __hmax(a, b); + else if constexpr (Op == ReduceOp::MIN) + return __hmin(a, b); +} + +template +__device__ __forceinline__ half warp_reduce(half val) { + #pragma unroll + for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) { + if constexpr (Op == ReduceOp::SUM || Op == ReduceOp::MEAN) + val = __hadd(val, __shfl_down_sync(0xffffffff, val, mask)); + else if constexpr (Op == ReduceOp::MAX) + val = __hmax(val, __shfl_down_sync(0xffffffff, val, mask)); + else if constexpr (Op == ReduceOp::MIN) + val = __hmin(val, __shfl_down_sync(0xffffffff, val, mask)); + } + return val; +} + +template +__device__ __forceinline__ half finalize_result(half val, int count) { + if constexpr (Op == ReduceOp::MEAN) { + // 转换为float计算,然后转回half + float fval = __half2float(val); + float result = fval / static_cast(count); + return __float2half(result); + } else { + return val; + } +} + +template +__global__ void warp_final_reduce_kernel(float *temp_in, float *y, int num_blocks) { + int tid = threadIdx.x; + float thread_result = init_value(); + + for (int i = tid; i < num_blocks; i += blockDim.x) { + if constexpr (Op == ReduceOp::MAX) + thread_result = fmaxf(thread_result, temp_in[i]); + else if constexpr (Op == ReduceOp::MIN) + thread_result = fminf(thread_result, temp_in[i]); + } + thread_result = 
warp_reduce(thread_result); + + if (tid == 0) + y[0] = thread_result; +} + +template +__global__ void block_final_reduce_kernel(float *temp_in, float *y, int num_blocks) { + int tid = threadIdx.x; + __shared__ float s_data[256]; + + float thread_result = init_value(); + + for (int i = tid; i < num_blocks; i += blockDim.x) { + if constexpr (Op == ReduceOp::MAX) + thread_result = fmaxf(thread_result, temp_in[i]); + else if constexpr (Op == ReduceOp::MIN) + thread_result = fminf(thread_result, temp_in[i]); + } + + s_data[tid] = thread_result; + __syncthreads(); + + for (int stride = blockDim.x/2; stride > 0; stride >>= 1) { + if (tid < stride) { + if constexpr (Op == ReduceOp::MAX) + s_data[tid] = fmaxf(s_data[tid], s_data[tid + stride]); + else if constexpr (Op == ReduceOp::MIN) + s_data[tid] = fminf(s_data[tid], s_data[tid + stride]); + } + __syncthreads(); + } + + if (tid == 0) + y[0] = s_data[0]; +} + +template +__global__ void warp_final_reduce_kernel(half *temp_in, half *y, int num_blocks) { + int tid = threadIdx.x; + half thread_result = init_value(); + for (int i = tid; i < num_blocks; i += blockDim.x) { + if constexpr (Op == ReduceOp::MAX) + thread_result = __hmax(thread_result, temp_in[i]); + else if constexpr (Op == ReduceOp::MIN) + thread_result = __hmin(thread_result, temp_in[i]); + } + thread_result = warp_reduce(thread_result); + + if (tid == 0) + y[0] = thread_result; +} + +template +__global__ void block_final_reduce_kernel(half *temp_in, half *y, int num_blocks) { + int tid = threadIdx.x; + __shared__ half s_data[256]; + + half thread_result = init_value(); + + for (int i = tid; i < num_blocks; i += blockDim.x) { + if constexpr (Op == ReduceOp::MAX) + thread_result = __hmax(thread_result, temp_in[i]); + else if constexpr (Op == ReduceOp::MIN) + thread_result = __hmin(thread_result, temp_in[i]); + } + + s_data[tid] = thread_result; + __syncthreads(); + + for (int stride = blockDim.x/2; stride > 0; stride >>= 1) { + if (tid < stride) { + if constexpr (Op == ReduceOp::MAX) + s_data[tid] = __hmax(s_data[tid], s_data[tid + stride]); + else if constexpr (Op == ReduceOp::MIN) + s_data[tid] = __hmin(s_data[tid], s_data[tid + stride]); + } + __syncthreads(); + } + + if (tid == 0) + y[0] = s_data[0]; +} + +template +__global__ void blockall_reduce_f32x4_kernel(const float *x, float *y, uint64_t N){ + int tid = threadIdx.x; + int base_idx = (tid + 64 * blockIdx.x) * 4; + constexpr int NUM_WARPS = (64 + WARP_SIZE - 1) / WARP_SIZE; + __shared__ float reduce_smem[NUM_WARPS]; + float result = init_value(); + if (base_idx < N){ + if (base_idx + 3 < N){ + float4 reg_x = FLOAT4_CONST(x[base_idx]); + if constexpr (Op == ReduceOp::SUM || Op == ReduceOp::MEAN) { + result += reg_x.x + reg_x.y + reg_x.z + reg_x.w; + } else if constexpr (Op == ReduceOp::MAX) { + result = fmaxf(result, fmaxf(fmaxf(reg_x.x, reg_x.y), fmaxf(reg_x.z, reg_x.w))); + } else if constexpr (Op == ReduceOp::MIN) { + result = fminf(result, fminf(fminf(reg_x.x, reg_x.y), fminf(reg_x.z, reg_x.w))); + } + } else { + result = reduce_op(result, x[base_idx]); + if (base_idx + 1 < N) result = reduce_op(result, x[base_idx + 1]); + if (base_idx + 2 < N) result = reduce_op(result, x[base_idx + 2]); + if (base_idx + 3 < N) result = reduce_op(result, x[base_idx + 3]); + } + } + const int warp_id = tid / WARP_SIZE; + const int lane = tid % WARP_SIZE; + result = warp_reduce(result); + if (lane == 0) { + reduce_smem[warp_id] = result; + } + __syncthreads(); + if (warp_id == 0) { + result = (lane < NUM_WARPS) ? 
reduce_smem[lane] : init_value(); + result = warp_reduce(result); + if (tid == 0) { + if constexpr (Op == ReduceOp::MEAN) { + atomicAdd(y, result); + } + else if constexpr (Op == ReduceOp::MAX || Op == ReduceOp::MIN) y[blockIdx.x] = result; + } + } +} + +// blockallreduce kernel for half(f16x8) using template operations +template +__global__ void blockall_reduce_f16x8_kernel(const half *x, T *y, uint64_t N){ + int tid = threadIdx.x; + int base_idx = (tid + 32 * blockIdx.x) * 8; + constexpr int NUM_WARPS = (32 + WARP_SIZE - 1) / WARP_SIZE; + __shared__ T reduce_smem[NUM_WARPS]; + T result = init_value(); + if (base_idx < N){ + if (base_idx + 7 < N){ + half pack_x[8]; + LDST128BITS(pack_x[0]) = LDST128BITS_CONST(x[base_idx]); + #pragma unroll + for (int i = 0; i < 8; i++){ + if constexpr (std::is_same::value) { + result = reduce_op(result, pack_x[i]); + } else { + result = reduce_op(result, __half2float(pack_x[i])); + } + } + } else { + for (int i = 0; i < 8 && (base_idx + i) < N; i++){ + if constexpr (std::is_same::value) { + result = reduce_op(result, x[base_idx + i]); + } else { + result = reduce_op(result, __half2float(x[base_idx + i])); + } + } + } + } + const int warp_id = tid / WARP_SIZE; + const int lane = tid % WARP_SIZE; + + result = warp_reduce(result); + + if (lane == 0) { + reduce_smem[warp_id] = result; + } + + __syncthreads(); + + if (warp_id == 0){ + result = (lane < NUM_WARPS) ? reduce_smem[lane] : init_value(); + result = warp_reduce(result); + + if (tid == 0) { + if constexpr (Op == ReduceOp::MEAN && std::is_same::value) { + atomicAdd(y, result); + } + else if constexpr (Op == ReduceOp::MAX || Op == ReduceOp::MIN) y[blockIdx.x] = result; + } + } +} + +template +__global__ void reduce_f32x4_contigous_kernel( + const float *x, float *y, + int prefix_size, int suffix_size, + uint64_t output_size, uint64_t reduce_size){ + + extern __shared__ float shared_mem_float[]; + int tid = threadIdx.x; + int output_idx = blockIdx.x; + int prefix_idx = output_idx / suffix_size; + int suffix_idx = output_idx % suffix_size; + + if (output_idx >= output_size) return; + float result = init_value(); + int base_idx = prefix_idx * reduce_size * suffix_size + suffix_idx; + // const int vector_reduce_size = reduce_size / 4; + // suffix_size equals to inner_stride + for (int i = tid; i < reduce_size; i += blockDim.x){ + result = reduce_op(result, x[base_idx + i * suffix_size]); + } + + shared_mem_float[tid] = result; + __syncthreads(); + for (int s = blockDim.x / 2; s >= 32; s >>= 1) { + if (tid < s) { + shared_mem_float[tid] = reduce_op(shared_mem_float[tid], shared_mem_float[tid + s]); + } + __syncthreads(); + } + if (tid < 32){ + result = shared_mem_float[tid]; + result = warp_reduce(result); + if (tid == 0) { + y[output_idx] = finalize_result(result, reduce_size); + + } + } +} + +template +__global__ void reduce_f16x8_contigous_kernel( + const half *x, half *y, + int prefix_size, int suffix_size, + uint64_t output_size, uint64_t reduce_size){ + extern __shared__ float shared_mem_half[]; + int tid = threadIdx.x; + int output_idx = blockIdx.x; + int prefix_idx = output_idx / suffix_size; + int suffix_idx = output_idx % suffix_size; + + if (output_idx >= output_size) return; + float result = init_value(); + int base_idx = prefix_idx * reduce_size * suffix_size + suffix_idx; + for (int i = tid; i < reduce_size; i += blockDim.x){ + result = reduce_op(result, __half2float(x[base_idx + i * suffix_size])); + } + shared_mem_half[tid] = result; + __syncthreads(); + for (int s = blockDim.x / 2; s >= 32; s 
>>= 1) { + if (tid < s) { + shared_mem_half[tid] = reduce_op(shared_mem_half[tid], shared_mem_half[tid + s]); + } + __syncthreads(); + } + if (tid < 32){ + result = shared_mem_half[tid]; + result = warp_reduce(result); + // each block has one result + if (tid == 0) { + float fresult = finalize_result(result, reduce_size); + y[output_idx] = __float2half(fresult); + } + } +} +/* +[2, 3, 4, 5] tensor axes = [0, ] -> output tensor [1, 3, 1, 5] +output_idx = 0 + + +*/ +template +__global__ void reduce_f32_kernel( + const float *x, float *y, + const int64_t *x_strides, const int64_t *y_strides, + const uint64_t reduce_size, const uint64_t output_size, + const int64_t *non_reduce_axes, + const uint64_t axes_size, + const uint64_t *reduce_axes_stride, + const int64_t *reduce_axes, const int64_t INPUT_NDIM, + const uint64_t element_num +){ + extern __shared__ float shared_mem_float[]; + int tid = threadIdx.x; + int output_idx = blockIdx.x; + if (output_idx >= output_size) return; + float result = init_value(); + uint64_t out_coords[10] = {0}; + if constexpr (KeepDims) { + for (int i = 0; i < INPUT_NDIM - axes_size; i++) { + out_coords[non_reduce_axes[i]] = output_idx / y_strides[non_reduce_axes[i]]; + output_idx %= y_strides[non_reduce_axes[i]]; + } + } else { + for (int i = 0; i < INPUT_NDIM - axes_size; i++) { + out_coords[non_reduce_axes[i]] = output_idx / y_strides[i]; + output_idx %= y_strides[i]; + } + } + uint64_t baseoffset = 0; + for (int i = 0; i < INPUT_NDIM; i++) { + baseoffset += out_coords[i] * x_strides[i]; + } + for (uint64_t i = tid; i < reduce_size; i += blockDim.x){ + uint64_t offset = baseoffset; + uint64_t remaining = i; + for (int j = 0; j < axes_size; j++) { + const uint64_t axis = reduce_axes[j]; + const uint64_t coord = remaining / reduce_axes_stride[j]; + offset += coord * x_strides[axis]; + remaining %= reduce_axes_stride[j]; + } + result = reduce_op(result, x[offset]); + } + shared_mem_float[tid] = result; + __syncthreads(); + for (int s = blockDim.x / 2; s >= 32; s >>= 1) { + if (tid < s) { + shared_mem_float[tid] = reduce_op(shared_mem_float[tid], shared_mem_float[tid + s]); + } + __syncthreads(); + } + if (tid < 32){ + result = shared_mem_float[tid]; + result = warp_reduce(result); + if (tid == 0) { + y[blockIdx.x] = finalize_result(result, reduce_size); + } + } +} + + +template +infiniopStatus_t reduce_nv_gpu( + ReduceCudaDescriptor_t desc, + void *y, + void const *x, + void *stream){ + uint64_t N = desc->element_num; + if (desc->reduce_mode == 0){ + if constexpr(std::is_same::value){ + dim3 block(256 / 4); + dim3 grid((N + block.x * 4 - 1) / (block.x * 4)); + if (desc->reduce_op_type == 0) { + blockall_reduce_f32x4_kernel<<>>(reinterpret_cast(x), reinterpret_cast(y), N); + divide_by_n_kernel<<<1, 1, 0, (cudaStream_t)stream>>>( + reinterpret_cast(y), N); + } else if (desc->reduce_op_type == 1 || desc->reduce_op_type == 2) { + float* temp_buffer; + cudaMalloc(&temp_buffer, grid.x * sizeof(float)); + if (desc->reduce_op_type == 1) { + blockall_reduce_f32x4_kernel<<>>(reinterpret_cast(x), temp_buffer, N); + } else if (desc->reduce_op_type == 2) { + blockall_reduce_f32x4_kernel<<>>(reinterpret_cast(x), temp_buffer, N); + } + int num_blocks = grid.x; + if (num_blocks <= 32) { + if (desc->reduce_op_type == 1) { // MAX + warp_final_reduce_kernel<<<1, 32, 0, (cudaStream_t)stream>>>( + temp_buffer, reinterpret_cast(y), num_blocks); + } else { + warp_final_reduce_kernel<<<1, 32, 0, (cudaStream_t)stream>>>( + temp_buffer, reinterpret_cast(y), num_blocks); + } + } else { + if 
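+ // more than 32 per-block partials: finish the reduction with a 256-thread block instead of a single warp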
(desc->reduce_op_type == 1) { + block_final_reduce_kernel<<<1, 256, 0, (cudaStream_t)stream>>>( + temp_buffer, reinterpret_cast(y), num_blocks); + } else { + block_final_reduce_kernel<<<1, 256, 0, (cudaStream_t)stream>>>( + temp_buffer, reinterpret_cast(y), num_blocks); + } + } + cudaFree(temp_buffer); + } + } else { + dim3 block(256 / 8); + dim3 grid((N + block.x * 8 - 1) / (block.x * 8)); + if (desc->reduce_op_type == 0) { + float* temp_buffer; + cudaMalloc(&temp_buffer, sizeof(float)); + blockall_reduce_f16x8_kernel<<>>(reinterpret_cast(x), temp_buffer, N); + divide_by_n_kernel<<<1, 1, 0, (cudaStream_t)stream>>>( + reinterpret_cast(y), temp_buffer, N); + } else if (desc->reduce_op_type == 1 || desc->reduce_op_type == 2) { + half *temp_buffer; + cudaMalloc(&temp_buffer, grid.x * sizeof(half)); + if (desc->reduce_op_type == 1) { + blockall_reduce_f16x8_kernel<<>>(reinterpret_cast(x), reinterpret_cast(temp_buffer), N); + } else if (desc->reduce_op_type == 2) { + blockall_reduce_f16x8_kernel<<>>(reinterpret_cast(x), reinterpret_cast(temp_buffer), N); + } + int num_blocks = grid.x; + if (num_blocks <= 32) { + if (desc->reduce_op_type == 1) { + warp_final_reduce_kernel<<<1, 32, 0, (cudaStream_t)stream>>>( + temp_buffer, reinterpret_cast(y), num_blocks); + } else { + warp_final_reduce_kernel<<<1, 32, 0, (cudaStream_t)stream>>>( + temp_buffer, reinterpret_cast(y), num_blocks); + } + } else { + if (desc->reduce_op_type == 1) { + block_final_reduce_kernel<<<1, 256, 0, (cudaStream_t)stream>>>( + temp_buffer, reinterpret_cast(y), num_blocks); + } else { + block_final_reduce_kernel<<<1, 256, 0, (cudaStream_t)stream>>>( + temp_buffer, reinterpret_cast(y), num_blocks); + } + } + cudaFree(temp_buffer); + } + } + } + // contiguous axes + else if (desc->reduce_mode == 1){ + if constexpr(std::is_same::value){ + dim3 block(128); + dim3 grid(desc->output_size); + size_t shared_mem_size = block.x * sizeof(float); + if (desc->reduce_op_type == 0) { + reduce_f32x4_contigous_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), desc->prefix_size, desc->suffix_size, desc->output_size, desc->reduce_size); + } else if (desc->reduce_op_type == 1) { + reduce_f32x4_contigous_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), desc->prefix_size, desc->suffix_size, desc->output_size, desc->reduce_size); + } else if (desc->reduce_op_type == 2) { + reduce_f32x4_contigous_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), desc->prefix_size, desc->suffix_size, desc->output_size, desc->reduce_size); + } + } else { + dim3 block(128); + dim3 grid(desc->output_size); + size_t shared_mem_size = block.x * sizeof(half); + if (desc->reduce_op_type == 0) { + reduce_f16x8_contigous_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), desc->prefix_size, desc->suffix_size, desc->output_size, desc->reduce_size); + }else if (desc->reduce_op_type == 1) { + reduce_f16x8_contigous_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), desc->prefix_size, desc->suffix_size, desc->output_size, desc->reduce_size); + } else if (desc->reduce_op_type == 2) { + reduce_f16x8_contigous_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), desc->prefix_size, desc->suffix_size, desc->output_size, desc->reduce_size); + } + } + } + // not contiguous axes + else if (desc->reduce_mode == 2){ + dim3 block(128); + dim3 grid(desc->output_size); + size_t shared_mem_size = block.x * sizeof(float); + if constexpr(std::is_same::value){ + if (desc->keepdims){ + if (desc->reduce_op_type == 0) { + reduce_f32_kernel<<>>( + reinterpret_cast(x), 
reinterpret_cast(y), + desc->input_strides, desc->output_strides, desc->reduce_size, desc->output_size, + desc->non_reduce_axes, desc->axes_size, desc->reduce_axes_stride, desc->reduce_axes, desc->input_ndim, desc->element_num); + } + else if (desc->reduce_op_type == 1) { + reduce_f32_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), + desc->input_strides, desc->output_strides, desc->reduce_size, desc->output_size, + desc->non_reduce_axes, desc->axes_size, desc->reduce_axes_stride, desc->reduce_axes, desc->input_ndim, desc->element_num); + } + else if (desc->reduce_op_type == 2) { + reduce_f32_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), + desc->input_strides, desc->output_strides, desc->reduce_size, desc->output_size, + desc->non_reduce_axes, desc->axes_size, desc->reduce_axes_stride, desc->reduce_axes, desc->input_ndim, desc->element_num); + } + }else{ + if (desc->reduce_op_type == 0) { + reduce_f32_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), + desc->input_strides, desc->output_strides, desc->reduce_size, desc->output_size, + desc->non_reduce_axes, desc->axes_size, desc->reduce_axes_stride, desc->reduce_axes, desc->input_ndim, desc->element_num); + } + else if (desc->reduce_op_type == 1) { + reduce_f32_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), + desc->input_strides, desc->output_strides, desc->reduce_size, desc->output_size, + desc->non_reduce_axes, desc->axes_size, desc->reduce_axes_stride, desc->reduce_axes, desc->input_ndim, desc->element_num); + } + else if (desc->reduce_op_type == 2) { + reduce_f32_kernel<<>>( + reinterpret_cast(x), reinterpret_cast(y), + desc->input_strides, desc->output_strides, desc->reduce_size, desc->output_size, + desc->non_reduce_axes, desc->axes_size, desc->reduce_axes_stride, desc->reduce_axes, desc->input_ndim, desc->element_num); + } + } + } + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaReduce( + ReduceCudaDescriptor_t desc, + void *y, + void const *x, + void *stream){ + if (desc->dtype == F16) { + return reduce_nv_gpu(desc, y, x, stream); + } + if (desc->dtype == F32) { + return reduce_nv_gpu(desc, y, x, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} \ No newline at end of file diff --git a/src/ops/reduce/cuda/reduce_cuda.h b/src/ops/reduce/cuda/reduce_cuda.h new file mode 100644 index 00000000..9b0b2741 --- /dev/null +++ b/src/ops/reduce/cuda/reduce_cuda.h @@ -0,0 +1,61 @@ +#ifndef __CUDA_REDUCE_H__ +#define __CUDA_REDUCE_H__ + +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include + +typedef struct ReduceCudaDescriptor { + Device device; + DT dtype; + uint64_t input_ndim; + uint64_t output_ndim; + int64_t *non_reduce_axes; + int64_t *input_strides; + int64_t *output_strides; + uint64_t *input_shape; + uint64_t *output_shape; + int64_t *reduce_axes; + uint64_t *reduce_axes_stride; + uint64_t reduce_size; + uint64_t element_num; + uint64_t output_size; + int reduce_op_type; + int reduce_mode; // output_size * reduce_size = element_num + uint64_t axes_size; + bool keepdims; + int64_t start_axis; + int64_t end_axis; + int prefix_size; + int suffix_size; +} ReduceCudaDescriptor; + +typedef struct ReduceCudaDescriptor *ReduceCudaDescriptor_t; + + +enum ReduceOpType { + REDUCE_SUM = 0, + REDUCE_MIN = 1, + REDUCE_MAX = 2, + REDUCE_MEAN = 3 +}; + +infiniopStatus_t cudaCreateReduceDescriptor(CudaHandle_t handle, + ReduceCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t y, + int64_t const *axes, + uint64_t axes_size, + int reduce_op_type, + 
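+ // reduce_op_type as dispatched by the reducemean/reducemax/reducemin operators: 0 = mean, 1 = max, 2 = min (note this differs from the ReduceOpType enum above)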
bool keepdims + ); + + +infiniopStatus_t cudaReduce(ReduceCudaDescriptor_t desc, + void *y, + void const *x, + void *stream); + +infiniopStatus_t cudaDestroyReduceDescriptor(ReduceCudaDescriptor_t desc); + +#endif// __CUDA_MATMUL_H__ diff --git a/src/ops/reduce/operator.cc b/src/ops/reduce/operator.cc new file mode 100644 index 00000000..fcd38395 --- /dev/null +++ b/src/ops/reduce/operator.cc @@ -0,0 +1,64 @@ +#include "../utils.h" +#include "operators.h" +#include "reduce.h" + +#ifdef ENABLE_CPU +#include "cpu/reduce_cpu.h" +#endif + +#ifdef ENABLE_NV_GPU +#include "cuda/reduce_cuda.h" +#endif + + +__C infiniopStatus_t infiniopCreateReduceDescriptor( + infiniopHandle_t handle, + infiniopReduceDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + bool keepdims, + bool noop_with_empty_axes, + int reduce_type) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateReduceDescriptor(handle, (ReduceCpuDescriptor_t *) desc_ptr, y, x, axes, n, reduce_type, noop_with_empty_axes, keepdims); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: + return cudaCreateReduceDescriptor((CudaHandle_t)handle, (ReduceCudaDescriptor_t *) desc_ptr, y, x, axes, n, reduce_type, keepdims); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopReduce(infiniopReduceDescriptor_t desc, void *y, const void *x, void *dynamic_axes, uint64_t dynamic_axes_size, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuReduce((ReduceCpuDescriptor_t) desc, y, x, dynamic_axes, dynamic_axes_size, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: + return cudaReduce((ReduceCudaDescriptor_t) desc, y, x, stream); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyReduceDescriptor(infiniopReduceDescriptor_t desc){ + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyReduceDescriptor((ReduceCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: + return cudaDestroyReduceDescriptor((ReduceCudaDescriptor_t) desc); +#endif + } + return STATUS_BAD_DEVICE; + +} \ No newline at end of file diff --git a/src/ops/reduce/reduce.h b/src/ops/reduce/reduce.h new file mode 100644 index 00000000..47e49aa5 --- /dev/null +++ b/src/ops/reduce/reduce.h @@ -0,0 +1,25 @@ +#ifndef REDUCE_H +#define REDUCE_H + +#include "export.h" +#include "operators.h" + +typedef struct ReduceDescriptor { + Device device; +} ReduceDescriptor; +typedef ReduceDescriptor *infiniopReduceDescriptor_t; + +__C infiniopStatus_t infiniopCreateReduceDescriptor(infiniopHandle_t handle, + infiniopReduceDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + bool keepdims, + bool noop_with_empty_axes, + int reduce_type); + +__C infiniopStatus_t infiniopReduce(infiniopReduceDescriptor_t desc, void *y, const void *x, void *dynamic_axes, uint64_t dynamic_axes_size, void *stream); + +__C infiniopStatus_t infiniopDestroyReduceDescriptor(infiniopReduceDescriptor_t desc); +#endif \ No newline at end of file diff --git a/src/ops/reducemax/operator.cc b/src/ops/reducemax/operator.cc new file mode 100644 index 00000000..4b6c14ed --- /dev/null +++ b/src/ops/reducemax/operator.cc @@ -0,0 +1,41 @@ +#include "../utils.h" +#include "../reduce/reduce.h" +#include "ops/reducemax/reducemax.h" + +struct _ReducemaxDescriptor { + Device device; + infiniopReduceDescriptor_t reduce_desc; +}; + +typedef struct 
_ReducemaxDescriptor *_ReducemaxDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReducemaxDescriptor( + infiniopHandle_t handle, + infiniopReducemaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + bool keepdims, + bool noop_with_empty_axes + ) { + infiniopReduceDescriptor_t reduce_desc; + CHECK_STATUS(infiniopCreateReduceDescriptor(handle, &reduce_desc, y, x, axes, n, keepdims, noop_with_empty_axes, 1), STATUS_SUCCESS); + *(_ReducemaxDescriptor_t *) desc_ptr = new _ReducemaxDescriptor{ + handle->device, + reduce_desc + }; + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopReducemax(infiniopReducemaxDescriptor_t desc, void *y, const void *x, void *dynamic_axes, uint64_t dynamic_axes_size, void *stream) { + auto _desc = (_ReducemaxDescriptor_t) desc; + CHECK_STATUS(infiniopReduce(_desc->reduce_desc, y, x, dynamic_axes, dynamic_axes_size, stream), STATUS_SUCCESS); + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopDestroyReducemaxDescriptor(infiniopReducemaxDescriptor_t desc) { + CHECK_STATUS(infiniopDestroyReduceDescriptor(((_ReducemaxDescriptor_t) desc)->reduce_desc), STATUS_SUCCESS); + delete desc; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/reducemean/operator.cc b/src/ops/reducemean/operator.cc new file mode 100644 index 00000000..00b62374 --- /dev/null +++ b/src/ops/reducemean/operator.cc @@ -0,0 +1,41 @@ +#include "../utils.h" +#include "../reduce/reduce.h" +#include "ops/reducemean/reducemean.h" + +struct _ReducemeanDescriptor { + Device device; + infiniopReduceDescriptor_t reduce_desc; +}; + +typedef struct _ReducemeanDescriptor *_ReducemeanDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReducemeanDescriptor( + infiniopHandle_t handle, + infiniopReducemeanDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + bool keepdims, + bool noop_with_empty_axes + ) { + infiniopReduceDescriptor_t reduce_desc; + CHECK_STATUS(infiniopCreateReduceDescriptor(handle, &reduce_desc, y, x, axes, n, keepdims, noop_with_empty_axes, 0), STATUS_SUCCESS); + *(_ReducemeanDescriptor_t *) desc_ptr = new _ReducemeanDescriptor{ + handle->device, + reduce_desc + }; + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopReducemean(infiniopReducemeanDescriptor_t desc, void *y, void const *x, void *dynamic_axes, uint64_t dynamic_axes_size, void *stream) { + auto _desc = (_ReducemeanDescriptor_t) desc; + CHECK_STATUS(infiniopReduce(_desc->reduce_desc, y, x, dynamic_axes, dynamic_axes_size, stream), STATUS_SUCCESS); + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopDestroyReducemeanDescriptor(infiniopReducemeanDescriptor_t desc) { + CHECK_STATUS(infiniopDestroyReduceDescriptor(((_ReducemeanDescriptor_t) desc)->reduce_desc), STATUS_SUCCESS); + delete desc; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/reducemin/operator.cc b/src/ops/reducemin/operator.cc new file mode 100644 index 00000000..b17ff2b8 --- /dev/null +++ b/src/ops/reducemin/operator.cc @@ -0,0 +1,40 @@ +#include "../utils.h" +#include "../reduce/reduce.h" +#include "ops/reducemin/reducemin.h" + +struct _ReduceminDescriptor { + Device device; + infiniopReduceDescriptor_t reduce_desc; +}; + +typedef struct _ReduceminDescriptor *_ReduceminDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReduceminDescriptor( + infiniopHandle_t handle, + 
infiniopReduceminDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + int64_t const *axes, + uint64_t n, + bool keepdims, + bool noop_with_empty_axes + ) { + infiniopReduceDescriptor_t reduce_desc; + CHECK_STATUS(infiniopCreateReduceDescriptor(handle, &reduce_desc, y, x, axes, n, keepdims, noop_with_empty_axes,2), STATUS_SUCCESS); + *(_ReduceminDescriptor_t *) desc_ptr = new _ReduceminDescriptor{ + handle->device, + reduce_desc + }; + return STATUS_SUCCESS; +} +__C __export infiniopStatus_t infiniopReducemin(infiniopReduceminDescriptor_t desc, void *y, void const *x, void *dynamic_axes, uint64_t dynamic_axes_size, void *stream) { + auto _desc = (_ReduceminDescriptor_t) desc; + CHECK_STATUS(infiniopReduce(_desc->reduce_desc, y, x, dynamic_axes, dynamic_axes_size, stream), STATUS_SUCCESS); + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopDestroyReduceminDescriptor(infiniopReduceminDescriptor_t desc) { + CHECK_STATUS(infiniopDestroyReduceDescriptor(((_ReduceminDescriptor_t) desc)->reduce_desc), STATUS_SUCCESS); + delete desc; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/where/cpu/where_cpu.cc b/src/ops/where/cpu/where_cpu.cc new file mode 100644 index 00000000..5fd834c4 --- /dev/null +++ b/src/ops/where/cpu/where_cpu.cc @@ -0,0 +1,145 @@ +#include "where_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + +infiniopStatus_t cpuCreateWhereDescriptor(infiniopHandle_t handle, + WhereCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t dst, + infiniopTensorDescriptor_t src1, + infiniopTensorDescriptor_t src2, + infiniopTensorDescriptor_t condition + ){ + if (!isValidBroadcastShape(dst, src1) || !isValidBroadcastShape(dst, src2) || !isValidBroadcastShape(dst, condition)) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (dst->dt != F16 && dst->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (src1->dt != F16 && src1->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (src2->dt != F16 && src2->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (condition->dt != U8) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (!is_contiguous(dst) || !is_contiguous(src1) || !is_contiguous(src2) || !is_contiguous(condition)) { + return STATUS_BAD_TENSOR_STRIDES; + } + if (dst->dt != src1->dt || dst->dt != src2->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + uint64_t *src1_shape = new uint64_t[dst->ndim]; + uint64_t *src2_shape = new uint64_t[dst->ndim]; + uint64_t *condition_shape = new uint64_t[dst->ndim]; + uint64_t *dst_shape = new uint64_t[dst->ndim]; + + int64_t *dst_strides = new int64_t[dst->ndim]; + int64_t *src1_strides = new int64_t[src1->ndim]; + int64_t *src2_strides = new int64_t[src2->ndim]; + int64_t *condition_strides = new int64_t[condition->ndim]; + uint64_t dst_ndim = dst->ndim; + uint64_t element_num = 1; + for (uint64_t i = 0; i < dst->ndim; i++) { + element_num *= dst->shape[i]; + } + + memcpy(src1_shape, src1->shape, src1->ndim * sizeof(uint64_t)); + memcpy(src2_shape, src2->shape, src2->ndim * sizeof(uint64_t)); + memcpy(condition_shape, condition->shape, condition->ndim * sizeof(uint64_t)); + memcpy(dst_shape, dst->shape, dst->ndim * sizeof(uint64_t)); + + memcpy(dst_strides, dst->strides, dst->ndim * sizeof(int64_t)); + memcpy(src1_strides, src1->strides, src1->ndim * sizeof(int64_t)); + memcpy(src2_strides, src2->strides, src2->ndim * sizeof(int64_t)); + memcpy(condition_strides, condition->strides, condition->ndim * sizeof(int64_t)); + *desc_ptr = new 
WhereCpuDescriptor{
+        handle->device,
+        dst->dt,
+        src1_shape,
+        src2_shape,
+        condition_shape,
+        dst_shape,
+        dst_strides,
+        src1_strides,
+        src2_strides,
+        condition_strides,
+        dst_ndim,
+        src1->ndim,
+        src2->ndim,
+        condition->ndim,
+        element_num
+    };
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t cpuDestroyWhereDescriptor(WhereCpuDescriptor_t desc){
+    delete[] desc->src1_shape;
+    delete[] desc->src2_shape;
+    delete[] desc->condition_shape;
+    delete[] desc->dst_shape;
+    delete[] desc->dst_strides;
+    delete[] desc->src1_strides;
+    delete[] desc->src2_strides;
+    delete[] desc->condition_strides;
+    delete desc;
+    return STATUS_SUCCESS;
+}
+// Maps a flat index into dst onto the offset of the corresponding element in a
+// (possibly lower-rank, broadcast) input tensor.
+inline uint64_t broadcast_map(
+    uint64_t idx,
+    const uint64_t* dst_shape,
+    uint64_t dst_ndim,
+    const uint64_t* input_shape,
+    const int64_t* input_strides,
+    uint64_t input_ndim
+) {
+    uint64_t index = 0;
+    const int offset = dst_ndim - input_ndim;
+    for (int i = dst_ndim - 1; i >= 0; idx /= dst_shape[i--]) {
+        const uint64_t coord = idx % dst_shape[i];
+        if (i >= offset) {
+            const int dim = i - offset;
+            const uint64_t size = input_shape[dim];
+            index += (size == 1 ? 0 : coord % size) * input_strides[dim];
+        }
+    }
+    return index;
+}
+
+
+template <typename Tdata>
+infiniopStatus_t where_cpu(WhereCpuDescriptor_t desc,
+                           void *dst,
+                           void const *src1,
+                           void const *src2,
+                           void const *condition,
+                           void *stream){
+    auto dst_ = reinterpret_cast<Tdata *>(dst);
+    auto src1_ = reinterpret_cast<Tdata const *>(src1);
+    auto src2_ = reinterpret_cast<Tdata const *>(src2);
+    auto condition_ = reinterpret_cast<uint8_t const *>(condition);
+    #pragma omp parallel for
+    for (uint64_t i = 0; i < desc->element_num; i++) {
+        uint64_t condition_index = broadcast_map(i, desc->dst_shape, desc->dst_ndim, desc->condition_shape, desc->condition_strides, desc->condition_ndim);
+        uint64_t src1_index = broadcast_map(i, desc->dst_shape, desc->dst_ndim, desc->src1_shape, desc->src1_strides, desc->src1_ndim);
+        uint64_t src2_index = broadcast_map(i, desc->dst_shape, desc->dst_ndim, desc->src2_shape, desc->src2_strides, desc->src2_ndim);
+        dst_[i] = condition_[condition_index] ? src1_[src1_index] : src2_[src2_index];
+    }
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t cpuWhere(WhereCpuDescriptor_t desc,
+                          void *dst,
+                          void const *src1,
+                          void const *src2,
+                          void const *condition,
+                          void *stream){
+    if (desc->dtype == F16) {
+        // F16 values are only selected, never computed on, so raw 16-bit words suffice.
+        return where_cpu<uint16_t>(desc, dst, src1, src2, condition, stream);
+    }
+    if (desc->dtype == F32) {
+        return where_cpu<float>(desc, dst, src1, src2, condition, stream);
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}
\ No newline at end of file
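For reference, a minimal standalone host-side sketch (not part of the patch) of the index arithmetic broadcast_map performs for both the CPU and CUDA paths; the example shapes, strides, and the main() driver are made-up values:

#include <cstdint>
#include <cstdio>

// Standalone copy of the mapping used above: walk the output coordinates from
// the innermost dimension outward and accumulate the input offset, treating
// size-1 input dimensions (and missing leading dimensions) as broadcast.
static uint64_t broadcast_map(uint64_t idx,
                              const uint64_t *dst_shape, uint64_t dst_ndim,
                              const uint64_t *input_shape,
                              const int64_t *input_strides, uint64_t input_ndim) {
    uint64_t index = 0;
    const int offset = dst_ndim - input_ndim;
    for (int i = dst_ndim - 1; i >= 0; idx /= dst_shape[i--]) {
        const uint64_t coord = idx % dst_shape[i];
        if (i >= offset) {
            const int dim = i - offset;
            const uint64_t size = input_shape[dim];
            index += (size == 1 ? 0 : coord % size) * input_strides[dim];
        }
    }
    return index;
}

int main() {
    // Example: dst has shape [2, 3]; condition has shape [3] with contiguous stride [1].
    const uint64_t dst_shape[] = {2, 3};
    const uint64_t cond_shape[] = {3};
    const int64_t cond_strides[] = {1};
    // Flat dst index 4 is coordinate (1, 1), so the condition offset is 1.
    std::printf("%llu\n", (unsigned long long) broadcast_map(4, dst_shape, 2, cond_shape, cond_strides, 1));
    return 0;
}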
diff --git a/src/ops/where/cpu/where_cpu.h b/src/ops/where/cpu/where_cpu.h
new file mode 100644
index 00000000..f422980c
--- /dev/null
+++ b/src/ops/where/cpu/where_cpu.h
@@ -0,0 +1,42 @@
+#ifndef __CPU_WHERE_H__
+#define __CPU_WHERE_H__
+
+#include "operators.h"
+struct WhereCpuDescriptor {
+    Device device;
+    DT dtype;
+    uint64_t const *src1_shape;
+    uint64_t const *src2_shape;
+    uint64_t const *condition_shape;
+    uint64_t const *dst_shape;
+    int64_t const *dst_strides;
+    int64_t const *src1_strides;
+    int64_t const *src2_strides;
+    int64_t const *condition_strides;
+    uint64_t dst_ndim;
+    uint64_t src1_ndim;
+    uint64_t src2_ndim;
+    uint64_t condition_ndim;
+    uint64_t element_num;
+};
+
+typedef struct WhereCpuDescriptor *WhereCpuDescriptor_t;
+
+infiniopStatus_t cpuCreateWhereDescriptor(infiniopHandle_t handle,
+                                          WhereCpuDescriptor_t *desc_ptr,
+                                          infiniopTensorDescriptor_t dst,
+                                          infiniopTensorDescriptor_t src1,
+                                          infiniopTensorDescriptor_t src2,
+                                          infiniopTensorDescriptor_t condition
+                                          );
+
+infiniopStatus_t cpuWhere(WhereCpuDescriptor_t desc,
+                          void *dst,
+                          void const *src1,
+                          void const *src2,
+                          void const *condition,
+                          void *stream);
+
+infiniopStatus_t cpuDestroyWhereDescriptor(WhereCpuDescriptor_t desc);
+
+#endif
diff --git a/src/ops/where/cuda/where_cuda.cc b/src/ops/where/cuda/where_cuda.cc
new file mode 100644
index 00000000..35701773
--- /dev/null
+++ b/src/ops/where/cuda/where_cuda.cc
@@ -0,0 +1,120 @@
+#include "where_cuda.h"
+#include "../../../devices/cuda/common_cuda.h"
+#include "../../utils.h"
+
+infiniopStatus_t cudaCreateWhereDescriptor(CudaHandle_t handle,
+                                           WhereCudaDescriptor_t *desc_ptr,
+                                           infiniopTensorDescriptor_t dst,
+                                           infiniopTensorDescriptor_t src1,
+                                           infiniopTensorDescriptor_t src2,
+                                           infiniopTensorDescriptor_t condition
+                                           ) {
+    if (!isValidBroadcastShape(dst, src1) || !isValidBroadcastShape(dst, src2) || !isValidBroadcastShape(dst, condition)) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    if (dst->dt != F16 && dst->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (src1->dt != F16 && src1->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (src2->dt != F16 && src2->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (condition->dt != U8) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (!is_contiguous(dst) || !is_contiguous(src1) || !is_contiguous(src2) || !is_contiguous(condition)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+    if (dst->dt != src1->dt || dst->dt != src2->dt) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    uint64_t *src1_shape = new uint64_t[dst->ndim];
+    uint64_t *src2_shape = new uint64_t[dst->ndim];
+    uint64_t *condition_shape = new uint64_t[dst->ndim];
+    uint64_t *dst_shape = new uint64_t[dst->ndim];
+
+    int64_t *dst_strides = new int64_t[dst->ndim];
+    int64_t *src1_strides = new int64_t[src1->ndim];
+    int64_t *src2_strides = new int64_t[src2->ndim];
+    int64_t *condition_strides = new int64_t[condition->ndim];
+    uint64_t dst_ndim = dst->ndim;
+    uint64_t element_num = 1;
+    for (uint64_t i = 0; i < dst->ndim; i++) {
+        element_num *= dst->shape[i];
+    }
+
+    memcpy(src1_shape,
src1->shape, src1->ndim * sizeof(uint64_t));
+    memcpy(src2_shape, src2->shape, src2->ndim * sizeof(uint64_t));
+    memcpy(condition_shape, condition->shape, condition->ndim * sizeof(uint64_t));
+    memcpy(dst_shape, dst->shape, dst->ndim * sizeof(uint64_t));
+
+    memcpy(dst_strides, dst->strides, dst->ndim * sizeof(int64_t));
+    memcpy(src1_strides, src1->strides, src1->ndim * sizeof(int64_t));
+    memcpy(src2_strides, src2->strides, src2->ndim * sizeof(int64_t));
+    memcpy(condition_strides, condition->strides, condition->ndim * sizeof(int64_t));
+
+    uint64_t *src1_device_shape, *src2_device_shape, *condition_device_shape, *dst_device_shape;
+    int64_t *dst_device_strides, *src1_device_strides, *src2_device_strides, *condition_device_strides;
+    checkCudaErrorWithCode(cudaMalloc((void**)&src1_device_shape, src1->ndim * sizeof(uint64_t)), STATUS_MEMORY_NOT_ALLOCATED);
+    checkCudaErrorWithCode(cudaMalloc((void**)&src2_device_shape, src2->ndim * sizeof(uint64_t)), STATUS_MEMORY_NOT_ALLOCATED);
+    checkCudaErrorWithCode(cudaMalloc((void**)&condition_device_shape, condition->ndim * sizeof(uint64_t)), STATUS_MEMORY_NOT_ALLOCATED);
+    checkCudaErrorWithCode(cudaMalloc((void**)&dst_device_shape, dst->ndim * sizeof(uint64_t)), STATUS_MEMORY_NOT_ALLOCATED);
+
+    checkCudaErrorWithCode(cudaMalloc((void**)&dst_device_strides, dst->ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED);
+    checkCudaErrorWithCode(cudaMalloc((void**)&src1_device_strides, src1->ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED);
+    checkCudaErrorWithCode(cudaMalloc((void**)&src2_device_strides, src2->ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED);
+    checkCudaErrorWithCode(cudaMalloc((void**)&condition_device_strides, condition->ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED);
+
+    checkCudaErrorWithCode(cudaMemcpy(src1_device_shape, src1_shape, src1->ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaMemcpy(src2_device_shape, src2_shape, src2->ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaMemcpy(condition_device_shape, condition_shape, condition->ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaMemcpy(dst_device_shape, dst->shape, dst->ndim * sizeof(uint64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED);
+
+    checkCudaErrorWithCode(cudaMemcpy(dst_device_strides, dst_strides, dst->ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaMemcpy(src1_device_strides, src1_strides, src1->ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaMemcpy(src2_device_strides, src2_strides, src2->ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaMemcpy(condition_device_strides, condition_strides, condition->ndim * sizeof(int64_t), cudaMemcpyHostToDevice), STATUS_EXECUTION_FAILED);
+
+    *desc_ptr = new WhereCudaDescriptor{
+        handle->device,
+        dst->dt,
+        src1_device_shape,
+        src2_device_shape,
+        condition_device_shape,
+        dst_device_shape,
+        dst_device_strides,
+        src1_device_strides,
+        src2_device_strides,
+        condition_device_strides,
+        dst_ndim,
+        src1->ndim,
+        src2->ndim,
+        condition->ndim,
+        element_num
+    };
+    delete[] src1_shape;
+    delete[] src2_shape;
+    delete[] condition_shape;
+    delete[] dst_shape;
+    delete[] dst_strides;
+    delete[] src1_strides;
+    delete[] src2_strides;
+    delete[] condition_strides;
+    return STATUS_SUCCESS;
+}
+
+
+infiniopStatus_t cudaDestroyWhereDescriptor(WhereCudaDescriptor_t desc) {
+    checkCudaErrorWithCode(cudaFree((void*)desc->src1_shape), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaFree((void*)desc->src2_shape), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaFree((void*)desc->condition_shape), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaFree((void*)desc->dst_shape), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaFree((void*)desc->dst_strides), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaFree((void*)desc->src1_strides), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaFree((void*)desc->src2_strides), STATUS_EXECUTION_FAILED);
+    checkCudaErrorWithCode(cudaFree((void*)desc->condition_strides), STATUS_EXECUTION_FAILED);
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/where/cuda/where_cuda.cu b/src/ops/where/cuda/where_cuda.cu
new file mode 100644
index 00000000..d1873803
--- /dev/null
+++ b/src/ops/where/cuda/where_cuda.cu
@@ -0,0 +1,184 @@
+#include "../../../devices/cuda/cuda_handle.h"
+#include "../../utils.h"
+#include "where_cuda.h"
+#include <cuda_fp16.h>
+
+// Both macros reinterpret 4 consecutive floats / 8 consecutive halves as one
+// float4 so that loads and stores go out as a single 128-bit transaction.
+#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
+#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
+
+__device__ __forceinline__ uint64_t broadcast_map(
+    uint64_t idx,
+    const uint64_t* dst_shape,
+    uint64_t dst_ndim,
+    const uint64_t* input_shape,
+    const int64_t* input_strides,
+    uint64_t input_ndim
+) {
+    uint64_t index = 0;
+    const int offset = dst_ndim - input_ndim;
+    for (int i = dst_ndim - 1; i >= 0; idx /= dst_shape[i--]) {
+        const uint64_t coord = idx % dst_shape[i];
+        if (i >= offset) {
+            const int dim = i - offset;
+            const uint64_t size = input_shape[dim];
+            index += (size == 1 ? 0 : coord % size) * input_strides[dim];
+        }
+    }
+    return index;
+}
+
+
+
+__global__ void where_f32x4_kernel(const float* src1, uint64_t const *src1_shape,
+                                   const float* src2, uint64_t const *src2_shape,
+                                   float* dst, uint64_t const *dst_shape,
+                                   const uint8_t* condition, uint64_t const *condition_shape,
+                                   uint64_t dst_ndim, uint64_t src1_ndim,
+                                   uint64_t src2_ndim, uint64_t condition_ndim,
+                                   int64_t const *condition_strides,
+                                   int64_t const *src1_strides,
+                                   int64_t const *src2_strides,
+                                   int64_t const *dst_strides,
+                                   int N) {
+
+    int idx_base = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
+    if (idx_base >= N) return;
+    bool valid[4] = {
+        (idx_base + 0) < N,
+        (idx_base + 1) < N,
+        (idx_base + 2) < N,
+        (idx_base + 3) < N
+    };
+    uint64_t condition_offset[4], src1_offset[4], src2_offset[4];
+    float4 reg_dst;
+    #pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        if (valid[i]) {
+            uint64_t idx = idx_base + i;
+            condition_offset[i] = broadcast_map(idx, dst_shape, dst_ndim, condition_shape,
+                                                condition_strides, condition_ndim);
+            src1_offset[i] = broadcast_map(idx, dst_shape, dst_ndim, src1_shape,
+                                           src1_strides, src1_ndim);
+            src2_offset[i] = broadcast_map(idx, dst_shape, dst_ndim, src2_shape,
+                                           src2_strides, src2_ndim);
+            float val = condition[condition_offset[i]] ?
src1[src1_offset[i]] : src2[src2_offset[i]];
+            ((float *)&reg_dst)[i] = val;
+        }
+    }
+    if (valid[0] && valid[1] && valid[2] && valid[3]) {
+        FLOAT4(dst[idx_base]) = reg_dst;
+    } else {
+        for (int i = 0; i < 4; ++i) {
+            if (valid[i]) {
+                dst[idx_base + i] = ((float *)&reg_dst)[i];
+            }
+        }
+    }
+}
+
+__global__ void where_f16x8_kernel(const half* src1, uint64_t const *src1_shape,
+                                   const half* src2, uint64_t const *src2_shape,
+                                   half* dst, uint64_t const *dst_shape,
+                                   const uint8_t* condition, uint64_t const *condition_shape,
+                                   uint64_t dst_ndim, uint64_t src1_ndim,
+                                   uint64_t src2_ndim, uint64_t condition_ndim,
+                                   int64_t const *condition_strides,
+                                   int64_t const *src1_strides,
+                                   int64_t const *src2_strides,
+                                   int64_t const *dst_strides,
+                                   int N){
+    int idx_base = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+    if (idx_base >= N) return;
+    int64_t condition_offset[8], src1_offset[8], src2_offset[8];
+    half pack_dst[8];
+    bool valid[8] = {
+        (idx_base + 0) < N,
+        (idx_base + 1) < N,
+        (idx_base + 2) < N,
+        (idx_base + 3) < N,
+        (idx_base + 4) < N,
+        (idx_base + 5) < N,
+        (idx_base + 6) < N,
+        (idx_base + 7) < N
+    };
+    #pragma unroll
+    for (int i = 0; i < 8; i++){
+        if (valid[i]){
+            uint64_t idx = idx_base + i;
+            condition_offset[i] = broadcast_map(idx, dst_shape, dst_ndim, condition_shape,
+                                                condition_strides, condition_ndim);
+            src1_offset[i] = broadcast_map(idx, dst_shape, dst_ndim, src1_shape,
+                                           src1_strides, src1_ndim);
+            src2_offset[i] = broadcast_map(idx, dst_shape, dst_ndim, src2_shape,
+                                           src2_strides, src2_ndim);
+            pack_dst[i] = condition[condition_offset[i]] ? src1[src1_offset[i]] : src2[src2_offset[i]];
+        }
+    }
+    if (valid[0] && valid[1] && valid[2] && valid[3] && valid[4] && valid[5] && valid[6] && valid[7]){
+        LDST128BITS(dst[idx_base]) = LDST128BITS(pack_dst[0]);
+    } else {
+        for (int i = 0; i < 8; i++){
+            if (valid[i]){
+                dst[idx_base + i] = pack_dst[i];
+            }
+        }
+    }
+}
+
+template <typename Tdata>
+infiniopStatus_t where_nv_gpu(
+    WhereCudaDescriptor_t desc,
+    void *dst,
+    void const *src1,
+    void const *src2,
+    void const *condition,
+    int per_thread_element,
+    void *stream){
+    uint64_t N = desc->element_num;
+    dim3 block(256 / per_thread_element);
+    dim3 grid((N + 256 - 1) / 256);
+    if constexpr (std::is_same<Tdata, float>::value){
+        where_f32x4_kernel<<<grid, block, 0, (cudaStream_t) stream>>>(
+            reinterpret_cast<const float *>(src1), desc->src1_shape,
+            reinterpret_cast<const float *>(src2), desc->src2_shape,
+            reinterpret_cast<float *>(dst), desc->dst_shape,
+            reinterpret_cast<const uint8_t *>(condition), desc->condition_shape,
+            desc->dst_ndim, desc->src1_ndim,
+            desc->src2_ndim, desc->condition_ndim,
+            desc->condition_strides,
+            desc->src1_strides,
+            desc->src2_strides,
+            desc->dst_strides,
+            N);
+    } else if constexpr (std::is_same<Tdata, half>::value){
+        where_f16x8_kernel<<<grid, block, 0, (cudaStream_t) stream>>>(
+            reinterpret_cast<const half *>(src1), desc->src1_shape,
+            reinterpret_cast<const half *>(src2), desc->src2_shape,
+            reinterpret_cast<half *>(dst), desc->dst_shape,
+            reinterpret_cast<const uint8_t *>(condition), desc->condition_shape,
+            desc->dst_ndim, desc->src1_ndim,
+            desc->src2_ndim, desc->condition_ndim,
+            desc->condition_strides,
+            desc->src1_strides,
+            desc->src2_strides,
+            desc->dst_strides,
+            N);
+    }
+    return STATUS_SUCCESS;
+}
+
+
+infiniopStatus_t cudaWhere(WhereCudaDescriptor_t desc,
+                           void *dst,
+                           void const *src1,
+                           void const *src2,
+                           void const *condition,
+                           void *stream){
+    if (desc->dtype == F16){
+        return where_nv_gpu<half>(desc, dst, src1, src2, condition, 8, stream);
+    }
+    if (desc->dtype == F32){
+        return where_nv_gpu<float>(desc, dst, src1, src2, condition, 4, stream);
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}
\ No newline at end of file
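Both kernels above follow the same 128-bit packing strategy: a thread computes 4 f32 (or 8 f16) results into registers and commits them with a single vectorized store whenever the whole group is in range, falling back to per-element stores on the ragged tail. A stripped-down sketch of just that store pattern (an illustrative copy kernel, not part of the patch):

#include <cuda_runtime.h>

// Each thread owns 4 consecutive floats; the reinterpret_cast issues one
// 128-bit store when all four lanes are in bounds, mirroring FLOAT4()/LDST128BITS().
__global__ void copy_f32x4(const float *src, float *dst, int n) {
    int idx_base = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
    if (idx_base >= n) return;
    float4 reg;
    float *lane = reinterpret_cast<float *>(&reg);
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        if (idx_base + i < n) lane[i] = src[idx_base + i];
    }
    if (idx_base + 3 < n) {
        reinterpret_cast<float4 *>(&dst[idx_base])[0] = reg;  // one vectorized store
    } else {
        for (int i = 0; i < 4; ++i) {
            if (idx_base + i < n) dst[idx_base + i] = lane[i];
        }
    }
}

With per_thread_element = 4 this matches the launch geometry chosen in where_nv_gpu: 256 / 4 = 64 threads per block and (N + 255) / 256 blocks, so each block covers 256 output elements; the f16 path uses 32 threads of 8 halves each for the same 256.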
diff --git a/src/ops/where/cuda/where_cuda.h b/src/ops/where/cuda/where_cuda.h
new file mode 100644
index 00000000..5c2f6d8f
--- /dev/null
+++ b/src/ops/where/cuda/where_cuda.h
@@ -0,0 +1,46 @@
+#ifndef __CUDA_WHERE_H__
+#define __CUDA_WHERE_H__
+
+#include "../../../devices/cuda/cuda_handle.h"
+#include "operators.h"
+#include
+
+typedef struct WhereCudaDescriptor {
+    Device device;
+    DT dtype;
+    uint64_t const *src1_shape;
+    uint64_t const *src2_shape;
+    uint64_t const *condition_shape;
+    uint64_t const *dst_shape;
+    int64_t const *dst_strides;
+    int64_t const *src1_strides;
+    int64_t const *src2_strides;
+    int64_t const *condition_strides;
+    uint64_t dst_ndim;
+    uint64_t src1_ndim;
+    uint64_t src2_ndim;
+    uint64_t condition_ndim;
+    uint64_t element_num;
+} WhereCudaDescriptor;
+
+typedef struct WhereCudaDescriptor *WhereCudaDescriptor_t;
+
+infiniopStatus_t cudaCreateWhereDescriptor(CudaHandle_t handle,
+                                           WhereCudaDescriptor_t *desc_ptr,
+                                           infiniopTensorDescriptor_t dst,
+                                           infiniopTensorDescriptor_t src1,
+                                           infiniopTensorDescriptor_t src2,
+                                           infiniopTensorDescriptor_t condition
+                                           );
+
+
+infiniopStatus_t cudaWhere(WhereCudaDescriptor_t desc,
+                           void *dst,
+                           void const *src1,
+                           void const *src2,
+                           void const *condition,
+                           void *stream);
+
+infiniopStatus_t cudaDestroyWhereDescriptor(WhereCudaDescriptor_t desc);
+
+#endif
diff --git a/src/ops/where/operator.cc b/src/ops/where/operator.cc
new file mode 100644
index 00000000..512638d6
--- /dev/null
+++ b/src/ops/where/operator.cc
@@ -0,0 +1,63 @@
+#include "../utils.h"
+#include "operators.h"
+#include "ops/where/where.h"
+
+#ifdef ENABLE_CPU
+#include "cpu/where_cpu.h"
+#endif
+
+#ifdef ENABLE_NV_GPU
+#include "cuda/where_cuda.h"
+#endif
+
+
+__C infiniopStatus_t infiniopCreateWhereDescriptor(
+    infiniopHandle_t handle,
+    infiniopWhereDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t dst,
+    infiniopTensorDescriptor_t src1,
+    infiniopTensorDescriptor_t src2,
+    infiniopTensorDescriptor_t condition
+    ) {
+    switch (handle->device) {
+#ifdef ENABLE_CPU
+        case DevCpu:
+            return cpuCreateWhereDescriptor(handle, (WhereCpuDescriptor_t *) desc_ptr, dst, src1, src2, condition);
+#endif
+#ifdef ENABLE_NV_GPU
+        case DevNvGpu: {
+            return cudaCreateWhereDescriptor((CudaHandle_t) handle, (WhereCudaDescriptor_t *) desc_ptr, dst, src1, src2, condition);
+        }
+#endif
+    }
+    return STATUS_BAD_DEVICE;
+}
+
+
+__C infiniopStatus_t infiniopWhere(infiniopWhereDescriptor_t desc, void *dst, void const *src1, void const *src2, void const *condition, void *stream) {
+    switch (desc->device) {
+#ifdef ENABLE_CPU
+        case DevCpu:
+            return cpuWhere((WhereCpuDescriptor_t) desc, dst, src1, src2, condition, stream);
+#endif
+#ifdef ENABLE_NV_GPU
+        case DevNvGpu:
+            return cudaWhere((WhereCudaDescriptor_t) desc, dst, src1, src2, condition, stream);
+#endif
+    }
+    return STATUS_BAD_DEVICE;
+}
+
+__C infiniopStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc){
+    switch (desc->device) {
+#ifdef ENABLE_CPU
+        case DevCpu:
+            return cpuDestroyWhereDescriptor((WhereCpuDescriptor_t) desc);
+#endif
+#ifdef ENABLE_NV_GPU
+        case DevNvGpu:
+            return cudaDestroyWhereDescriptor((WhereCudaDescriptor_t) desc);
+#endif
+    }
+    return STATUS_BAD_DEVICE;
+}
\ No newline at end of file
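End to end, the new operator is driven through the public API shown above together with the tensor-descriptor API below. A hypothetical usage sketch on the CPU backend: the handle is assumed to come from the library's existing handle-creation routine (outside this patch), F32 and U8 are assumed to be the DataLayout constants matching the dtype checks in the descriptors, and error handling is omitted:

#include "infini_operators.h"
#include <cstdint>
#include <vector>

// Hypothetical end-to-end use of Where on a 2x3 output, with the condition
// broadcast from shape [3]; everything not declared in this patch is assumed.
infiniopStatus_t where_example(infiniopHandle_t handle) {
    uint64_t out_shape[2] = {2, 3};
    int64_t out_strides[2] = {3, 1};
    uint64_t cond_shape[1] = {3};
    int64_t cond_strides[1] = {1};

    infiniopTensorDescriptor_t dst, src1, src2, cond;
    infiniopCreateTensorDescriptor(&dst, 2, out_shape, out_strides, F32);   // F32/U8 assumed constants
    infiniopCreateTensorDescriptor(&src1, 2, out_shape, out_strides, F32);
    infiniopCreateTensorDescriptor(&src2, 2, out_shape, out_strides, F32);
    infiniopCreateTensorDescriptor(&cond, 1, cond_shape, cond_strides, U8);

    infiniopWhereDescriptor_t desc;
    infiniopCreateWhereDescriptor(handle, &desc, dst, src1, src2, cond);

    std::vector<float> a(6, 1.0f), b(6, 2.0f), out(6);
    std::vector<uint8_t> c = {1, 0, 1};                 // one flag per output column
    // out[i] = c ? a : b, with c broadcast across the two rows of the output.
    infiniopWhere(desc, out.data(), a.data(), b.data(), c.data(), /*stream=*/nullptr);

    return infiniopDestroyWhereDescriptor(desc);
}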
diff --git a/src/tensor/tensor_descriptor.cc b/src/tensor/tensor_descriptor.cc
index 57afe92d..529a09fc 100644
--- a/src/tensor/tensor_descriptor.cc
+++ b/src/tensor/tensor_descriptor.cc
@@ -1,6 +1,7 @@
 #include "tensor/tensor_descriptor.h"
 #include
+
 __C __export infiniopStatus_t infiniopCreateTensorDescriptor(infiniopTensorDescriptor_t *desc_ptr, uint64_t ndim, uint64_t const *shape_, int64_t const *strides_, DataLayout datatype) {
     uint64_t *shape = new uint64_t[ndim];
     int64_t *strides = new int64_t[ndim];