From 9d25304174238709f3673d47f8dfa414b96b77a4 Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Fri, 20 Dec 2024 16:35:12 +0800 Subject: [PATCH 1/4] add cpu concat --- include/ops/concat/concat.h | 32 ++++ operatorspy/tests/concat.py | 217 ++++++++++++++++++++++++ src/ops/concat/cpu/concat_cpu.cc | 282 +++++++++++++++++++++++++++++++ src/ops/concat/cpu/concat_cpu.h | 49 ++++++ src/ops/concat/operator.cc | 65 +++++++ 5 files changed, 645 insertions(+) create mode 100644 include/ops/concat/concat.h create mode 100644 operatorspy/tests/concat.py create mode 100644 src/ops/concat/cpu/concat_cpu.cc create mode 100644 src/ops/concat/cpu/concat_cpu.h create mode 100644 src/ops/concat/operator.cc diff --git a/include/ops/concat/concat.h b/include/ops/concat/concat.h new file mode 100644 index 00000000..2399a686 --- /dev/null +++ b/include/ops/concat/concat.h @@ -0,0 +1,32 @@ +#ifndef CONCAT_H +#define CONCAT_H + +#include "../../export.h" +#include "../../operators.h" + +// Concat描述符结构 +typedef struct ConcatDescriptor { + Device device; // 设备类型(例如 DevCpu、DevNvGpu) + uint64_t axis; // 拼接轴(从0开始) +} ConcatDescriptor; + +typedef ConcatDescriptor *infiniopConcatDescriptor_t; + +// 创建Concat描述符 +__C __export infiniopStatus_t infiniopCreateConcatDescriptor(infiniopHandle_t handle, + infiniopConcatDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *x, + uint64_t num_inputs, + uint64_t axis); + +// 执行Concat操作 +__C __export infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc, + void *y, + void const **x, + void *stream); + +// 销毁Concat描述符 +__C __export infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc); + +#endif diff --git a/operatorspy/tests/concat.py b/operatorspy/tests/concat.py new file mode 100644 index 00000000..0b8a214f --- /dev/null +++ b/operatorspy/tests/concat.py @@ -0,0 +1,217 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64 +import ctypes +import sys +import os + +# 调整路径以导入 operatorspy 模块 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +from enum import Enum, auto +import torch + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + # 对于 concat 算子,通常不支持 in-place 操作,因此这里只保留 OUT_OF_PLACE + # 你可以根据实际需求扩展其他选项 + # INPLACE_A = auto() + # INPLACE_B = auto() + + +class ConcatDescriptor(Structure): + _fields_ = [("device", c_int32),] + + +infiniopConcatDescriptor_t = POINTER(ConcatDescriptor) + + +def concat_py(*tensors, dim=0): + """使用 PyTorch 进行拼接的辅助函数""" + return torch.cat(tensors, dim=dim) + + +def test( + lib, + handle, + torch_device, + c_shape, + axis, + input_shapes, + tensor_dtype=torch.float32, + inplace=Inplace.OUT_OF_PLACE, +): + """ + 测试 concat 算子 + """ + print( + f"Testing Concat on {torch_device} with output_shape:{c_shape}, input_shapes:{input_shapes}, axis:{axis}, dtype:{tensor_dtype}, inplace: {inplace.name}" + ) + + # 创建输入张量 + inputs = [torch.rand(shape, dtype=tensor_dtype).to(torch_device) for shape in input_shapes] + + for idx, tensor in enumerate(inputs): + print(f"Input {idx}:") + print(tensor) + print("-" * 50) + + # 创建输出张量 + if inplace == Inplace.OUT_OF_PLACE: + c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device) + else: + # 对于 concat,通常不支持 in-place 操作,因此这里简化为 OUT_OF_PLACE + c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device) + + # 使用 PyTorch 进行拼接,作为参考答案 + ans = concat_py(*inputs, dim=axis) + + print("ans:",ans) + print("-" * 50) + + # 将张量转换为 infiniop 所需的格式 + input_tensors = [to_tensor(t, lib) for t in inputs] + c_tensor = to_tensor(c, lib) if inplace == Inplace.OUT_OF_PLACE else to_tensor(c, lib) + + # 创建 Concat 描述符 + descriptor = infiniopConcatDescriptor_t() + + # 准备输入描述符数组 + num_inputs = len(input_tensors) + input_desc_array_type = infiniopTensorDescriptor_t * num_inputs + input_desc_array = input_desc_array_type(*[t.descriptor for t in input_tensors]) + + # 创建描述符 + check_error( + lib.infiniopCreateConcatDescriptor( + handle, + ctypes.byref(descriptor), + c_tensor.descriptor, # 使用 c_tensor 的描述符 + input_desc_array, # 输入张量描述符数组 + c_uint64(num_inputs), + c_uint64(axis), + ) + ) + + print("c1:",c) + print("-" * 50) + + # 执行拼接操作 + input_data_ptrs = (c_void_p * num_inputs)(*[t.data for t in input_tensors]) + check_error( + lib.infiniopConcat( + descriptor, + c_tensor.data, + ctypes.cast(input_data_ptrs, POINTER(c_void_p)), + None # 假设不需要流 + ) + ) + + print("c2:",c) + print("-" * 50) + + # 验证结果 + assert torch.allclose(c, ans, atol=0, rtol=1e-5), "Concat result does not match PyTorch's result." + + # 销毁描述符 + check_error(lib.infiniopDestroyConcatDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for c_shape, axis, input_shapes, inplace in test_cases: + test(lib, handle, "cpu", c_shape, axis, input_shapes, inplace=inplace) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for c_shape, axis, input_shapes, inplace in test_cases: + test(lib, handle, "cuda", c_shape, axis, input_shapes, inplace=inplace) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + for c_shape, axis, input_shapes, inplace in test_cases: + test(lib, handle, "mlu", c_shape, axis, input_shapes, inplace=inplace) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + # 定义测试用例 + test_cases = [ + # (output_shape, axis, input_shapes, inplace) + + ((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), + # ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), + # ((3, 7), 1, [(3, 2), (3, 4), (3,1)], Inplace.OUT_OF_PLACE), + # ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), + # ((1, 1), 0, [(1, 1)], Inplace.OUT_OF_PLACE), + # ((4, 5, 6), 0, [(1, 5, 6), (3, 5, 6)], Inplace.OUT_OF_PLACE), + # ((2, 3, 6), 2, [(2, 3, 2), (2, 3, 4)], Inplace.OUT_OF_PLACE), + + # 添加更多测试用例以覆盖不同的维度和拼接轴 + # ((2, 10, 3), 1, [(2, 5, 3), (2, 2, 3),(2,3,3)], Inplace.OUT_OF_PLACE), # 拼接沿第二维 + ] + + args = get_args() + lib = open_lib() + + # 绑定 C++ 函数 + # 创建 Concat 描述符 + lib.infiniopCreateConcatDescriptor.restype = c_int32 + lib.infiniopCreateConcatDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopConcatDescriptor_t), + infiniopTensorDescriptor_t, # 输出张量描述符 + POINTER(infiniopTensorDescriptor_t), # 输入张量描述符数组 + c_uint64, # 输入张量数量 + c_uint64, # 拼接轴 + ] + + # 执行 Concat + lib.infiniopConcat.restype = c_int32 + lib.infiniopConcat.argtypes = [ + infiniopConcatDescriptor_t, + c_void_p, # 输出数据指针 + POINTER(c_void_p), # 输入数据指针数组 + c_void_p, # 流(假设为 NULL) + ] + + # 销毁 Concat 描述符 + lib.infiniopDestroyConcatDescriptor.restype = c_int32 + lib.infiniopDestroyConcatDescriptor.argtypes = [ + infiniopConcatDescriptor_t, + ] + + # 根据命令行参数执行测试 + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + + print("\033[92mConcat Test passed!\033[0m") + + + + diff --git a/src/ops/concat/cpu/concat_cpu.cc b/src/ops/concat/cpu/concat_cpu.cc new file mode 100644 index 00000000..3d7b8d95 --- /dev/null +++ b/src/ops/concat/cpu/concat_cpu.cc @@ -0,0 +1,282 @@ +#include "concat_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + + +infiniopStatus_t cpuCreateConcatDescriptor( + infiniopHandle_t handle, + ConcatCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *x, + uint64_t num_inputs, + uint64_t axis) { + if (y == nullptr || x == nullptr || desc_ptr == nullptr || num_inputs == 0) { + return STATUS_BAD_PARAM; + } + + uint64_t ndim = y->ndim; // 输出张量维度 + if (axis >= ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + + uint64_t total_size = 0; // 拼接轴的总大小 + std::vector> input_shapes(num_inputs); // 输入张量形状 + std::vector> input_strides(num_inputs); // 输入张量步长 + + // 提取输出张量的形状和步长 + std::vector output_shape(y->shape, y->shape + ndim); + std::vector output_stride(y->strides, y->strides + ndim); + + // 验证输入张量的形状和步长,并记录形状信息 + for (size_t i = 0; i < num_inputs; ++i) { + + if (x[i]->dt != y->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + if (x[i]->ndim != ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + + for (size_t j = 0; j < ndim; ++j) { + if (j != axis && x[i]->shape[j] != y->shape[j]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + // 记录每个输入张量的形状和步长 + input_shapes[i] = std::vector(x[i]->shape, x[i]->shape + ndim); + input_strides[i] = std::vector(x[i]->strides, x[i]->strides + ndim); + + // 累加拼接轴的总大小 + total_size += x[i]->shape[axis]; + } + + // 验证输出张量形状是否匹配 + if (total_size != y->shape[axis]) { + return STATUS_BAD_TENSOR_SHAPE; + } + + // 初始化Concat描述符 + *desc_ptr = new ConcatCpuDescriptor{ + DevCpu, + y->dt, + axis, + ndim, + num_inputs, + input_shapes, + input_strides, + output_shape, + output_stride + }; + + return STATUS_SUCCESS; +} + + +// 销毁Concat描述符 +infiniopStatus_t cpuDestroyConcatDescriptor(ConcatCpuDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} + + +// Helper function to handle different data types +template +infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc, + T* y, + void const** x) { + uint64_t axis = desc->axis; + uint64_t num_inputs = desc->num_inputs; + const std::vector> &input_shapes = desc->input_shapes; + const std::vector> &input_strides = desc->input_strides; + const std::vector &output_shape = desc->output_shape; + const std::vector &output_stride = desc->output_stride; + uint64_t ndim = desc->ndim; + + + // 计算拼接轴之前的总元素数(外层维度) + uint64_t outer_dim = 1; + for (uint64_t d = 0; d < axis; ++d) { + outer_dim *= output_shape[d]; + } + + // 计算拼接轴之后的总元素数(内层维度) + uint64_t inner_dim = 1; + for (uint64_t d = axis + 1; d < ndim; ++d) { + inner_dim *= output_shape[d]; + } + + // 计算每个输入张量在拼接轴上的偏移量 + std::vector dim_offsets(num_inputs, 0); + for (uint64_t i = 1; i < num_inputs; ++i) { + dim_offsets[i] = dim_offsets[i - 1] + input_shapes[i - 1][axis]; + } + + // 并行化外部循环 + #pragma omp parallel for + for (uint64_t od = 0; od < outer_dim; ++od) { + // 计算当前外层索引在各维度上的位置 + // 例如,如果外层维度为 [d0, d1, ..., d(axis-1)] + // 则 od 可以被分解为 d0 * (output_stride[0]) + d1 * (output_stride[1]) + ... + std::vector indices(ndim, 0); + uint64_t tmp = od; + for (uint64_t d = 0; d < axis; ++d) { + indices[d] = tmp / output_stride[d]; + tmp %= output_stride[d]; + } + + for (uint64_t i = 0; i < num_inputs; ++i) { + // 输入张量的拼接轴上的偏移 + uint64_t input_axis_offset = dim_offsets[i]; + + // 遍历拼接轴上的所有元素 + for (uint64_t a = 0; a < input_shapes[i][axis]; ++a) { + // 设置当前拼接轴的索引 + indices[axis] = a + input_axis_offset; + + // 计算输出张量的线性索引 + uint64_t y_offset = 0; + for (uint64_t d = 0; d < ndim; ++d) { + y_offset += indices[d] * output_stride[d]; + } + + // 计算输入张量的线性索引 + uint64_t x_offset = 0; + for (uint64_t d = 0; d < ndim; ++d) { + x_offset += indices[d] * input_strides[i][d]; + } + + // 复制数据 + y[y_offset] = reinterpret_cast(x[i])[x_offset]; + } + } + } + + return STATUS_SUCCESS; +} + +// 主拼接函数 +infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc, + void *y, + void const **x, + void *stream) { + // 根据数据类型调用相应的模板实例 + switch (desc->dtype.size) { + case sizeof(float): // FLOAT32 + return concatCompute(desc, reinterpret_cast(y), x); + // 可以根据需要添加更多数据类型 + default: + return STATUS_SUCCESS; + } +} + + + + + +// infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc, +// void *y, +// void const **x, +// void *stream) { +// // 从描述符中获取必要信息 +// uint64_t axis = desc->axis; // 拼接轴 +// uint64_t num_inputs = desc->num_inputs; // 输入张量数量 +// const std::vector> &input_shapes = desc->input_shapes; // 输入张量形状 +// const std::vector> &input_strides = desc->input_strides; // 输入张量步长 +// const std::vector &output_shape = desc->output_shape; // 输出张量形状 +// const std::vector &output_stride = desc->output_stride; // 输出张量步长 + +// DT dtype = desc->dtype; +// size_t element_size = dtype.size; +// uint64_t ndim = desc->ndim; + +// // 初始化累计偏移量,用于拼接轴的起始位置 +// uint64_t cumulative_axis_offset = 0; + +// // 遍历每个输入张量 +// for (uint64_t tensor_idx = 0; tensor_idx < num_inputs; ++tensor_idx) { +// const uint8_t *x_data = static_cast(x[tensor_idx]); +// const auto &x_shape = input_shapes[tensor_idx]; +// const auto &x_stride = input_strides[tensor_idx]; + +// uint64_t axis_size = x_shape[axis]; // 当前张量在拼接轴上的大小 + +// // 计算非拼接轴的遍历总数(用于确定每个块的偏移量) +// uint64_t outer_loops = 1; +// for (uint64_t i = 0; i < ndim; ++i) { +// if (i != axis) { +// outer_loops *= x_shape[i]; +// } +// } + +// uint64_t x_size=1; +// for(int i=0; i < ndim; i++){ +// x_size = x_size * x_shape[i]; +// } +// x_size = x_size * element_size; + +// // 遍历非拼接轴的所有元素块 +// for (uint64_t outer_idx = 0; outer_idx < outer_loops; ++outer_idx) { +// // 将线性索引转换为多维索引 +// std::vector indices(ndim, 0); +// linearToMultiDim(indices, outer_idx, x_shape, axis); + +// // 计算输入和输出张量的偏移量 +// uint64_t input_block_offset =computeOffset(indices, x_stride) * element_size; +// indices[axis] += cumulative_axis_offset; +// uint64_t output_block_offset =computeOffset(indices, output_stride) * element_size; + +// // 计算剩余空间 +// // uint64_t remaining_space_in_x = x_size - input_block_offset; + +// uint64_t remaining_space_in_x = (x_stride[axis] - input_block_offset % x_stride[axis]) * element_size; + +// printf("remaining_space_in_x: %llu\n",remaining_space_in_x); + +// // 计算当前块的大小,但不能超过剩余空间 +// uint64_t block_size = std::min(axis_size * x_stride[axis] * element_size, remaining_space_in_x); + + +// memcpy(static_cast(y) + output_block_offset, x_data + input_block_offset, block_size); +// } + +// // 更新累计偏移量 +// cumulative_axis_offset += axis_size; +// } + +// return STATUS_SUCCESS; + +// } + + +void linearToMultiDim(std::vector &indices, + uint64_t linear_idx, + const std::vector &shape, + uint64_t exclude_axis) { + uint64_t ndim = shape.size(); + for (int64_t dim = ndim - 1; dim >= 0; --dim) { + if (dim == exclude_axis) + continue; // 跳过拼接轴 + indices[dim] = linear_idx % shape[dim]; + linear_idx /= shape[dim]; + } +} + + +uint64_t computeOffset(const std::vector &indices, + const std::vector &stride) { + uint64_t offset = 0; + for (uint64_t dim = 0; dim < indices.size(); ++dim) { + offset += indices[dim] * stride[dim]; + } + return offset; +} + + + +// uint64_t remaining_space_in_x = (x_stride[axis] - input_block_offset % x_stride[axis]) * element_size; + +// printf("remaining_space_in_x: %llu\n",remaining_space_in_x); + +// // 每次拷贝当前块 +// uint64_t block_size = axis_size * x_stride[axis] * element_size; \ No newline at end of file diff --git a/src/ops/concat/cpu/concat_cpu.h b/src/ops/concat/cpu/concat_cpu.h new file mode 100644 index 00000000..8431f198 --- /dev/null +++ b/src/ops/concat/cpu/concat_cpu.h @@ -0,0 +1,49 @@ +#ifndef __CPU_CONCAT_H__ +#define __CPU_CONCAT_H__ +#include "operators.h" +#include +#include + +// 支持高维拼接的CPU-specific Concat描述符 +struct ConcatCpuDescriptor { + Device device; // 设备类型(例如 DevCpu) + DT dtype; // 数据类型 + uint64_t axis; // 拼接轴(从0开始) + uint64_t ndim; // 张量维度 + uint64_t num_inputs; // 输入张量的数量 + std::vector> input_shapes; // 输入张量的形状 + std::vector> input_strides; // 输入张量的步长 + std::vector output_shape; // 输出张量的形状 + std::vector output_stride; // 输出张量的步长 +}; + + + +typedef struct ConcatCpuDescriptor *ConcatCpuDescriptor_t; + +// 创建Concat描述符 +infiniopStatus_t cpuCreateConcatDescriptor(infiniopHandle_t handle, + ConcatCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *x, + uint64_t num_inputs, + uint64_t axis); + +// 执行Concat操作 +infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc, + void *y, + void const **x, + void *stream); + +// 销毁Concat描述符 +infiniopStatus_t cpuDestroyConcatDescriptor(ConcatCpuDescriptor_t desc); + +void linearToMultiDim(std::vector &indices, + uint64_t linear_idx, + const std::vector &shape, + uint64_t exclude_axis); + +uint64_t computeOffset(const std::vector &indices, + const std::vector &stride); + +#endif diff --git a/src/ops/concat/operator.cc b/src/ops/concat/operator.cc new file mode 100644 index 00000000..e99d5de8 --- /dev/null +++ b/src/ops/concat/operator.cc @@ -0,0 +1,65 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/concat/concat.h" + +#ifdef ENABLE_CPU +#include "cpu/concat_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/concat.cuh" +#endif + +// 创建Concat描述符 +__C infiniopStatus_t infiniopCreateConcatDescriptor( + infiniopHandle_t handle, + infiniopConcatDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *x, + uint64_t num_inputs, + uint64_t axis) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateConcatDescriptor(handle, (ConcatCpuDescriptor_t *) desc_ptr, y, x, num_inputs, axis); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateConcatDescriptor((CudaHandle_t) handle, (ConcatCudaDescriptor_t *) desc_ptr, y, x, num_inputs, axis); + } +#endif + } + return STATUS_BAD_DEVICE; +} + +// 执行Concat操作 +__C infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc, void *y, void const **x, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuConcat((ConcatCpuDescriptor_t) desc, y, x, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaConcat((ConcatCudaDescriptor_t) desc, y, x, stream); + } +#endif + } + return STATUS_BAD_DEVICE; +} + +// 销毁Concat描述符 +__C infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyConcatDescriptor((ConcatCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroyConcatDescriptor((ConcatCudaDescriptor_t) desc); + } +#endif + } + return STATUS_BAD_DEVICE; +} From 6933d3887b4fffd9cb51cf9a05f4f088af07f25a Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Mon, 23 Dec 2024 18:51:40 +0800 Subject: [PATCH 2/4] feat(cpu): support concat negative axis --- include/ops/concat/concat.h | 11 +- operatorspy/liboperators.py | 3 - operatorspy/tests/concat.py | 123 ++++++++-------- src/ops/concat/cpu/concat_cpu.cc | 236 ++++++------------------------- src/ops/concat/cpu/concat_cpu.h | 31 +--- src/ops/concat/operator.cc | 7 +- 6 files changed, 120 insertions(+), 291 deletions(-) diff --git a/include/ops/concat/concat.h b/include/ops/concat/concat.h index 2399a686..20ca6339 100644 --- a/include/ops/concat/concat.h +++ b/include/ops/concat/concat.h @@ -4,29 +4,24 @@ #include "../../export.h" #include "../../operators.h" -// Concat描述符结构 typedef struct ConcatDescriptor { - Device device; // 设备类型(例如 DevCpu、DevNvGpu) - uint64_t axis; // 拼接轴(从0开始) + Device device; } ConcatDescriptor; typedef ConcatDescriptor *infiniopConcatDescriptor_t; -// 创建Concat描述符 __C __export infiniopStatus_t infiniopCreateConcatDescriptor(infiniopHandle_t handle, infiniopConcatDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y, infiniopTensorDescriptor_t *x, uint64_t num_inputs, - uint64_t axis); + int64_t axis); -// 执行Concat操作 __C __export infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc, void *y, void const **x, void *stream); - -// 销毁Concat描述符 + __C __export infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc); #endif diff --git a/operatorspy/liboperators.py b/operatorspy/liboperators.py index 868cc88d..fb58d6a7 100644 --- a/operatorspy/liboperators.py +++ b/operatorspy/liboperators.py @@ -10,7 +10,6 @@ LIB_OPERATORS_DIR = os.path.join(os.environ.get("INFINI_ROOT"), "lib") - class TensorDescriptor(Structure): _fields_ = [ ("dt", DataLayout), @@ -19,10 +18,8 @@ class TensorDescriptor(Structure): ("pattern", POINTER(c_int64)), ] - infiniopTensorDescriptor_t = ctypes.POINTER(TensorDescriptor) - class CTensor: def __init__(self, desc, data): self.descriptor = desc diff --git a/operatorspy/tests/concat.py b/operatorspy/tests/concat.py index 0b8a214f..f5cccba4 100644 --- a/operatorspy/tests/concat.py +++ b/operatorspy/tests/concat.py @@ -1,9 +1,8 @@ -from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64 +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_int64 import ctypes import sys import os -# 调整路径以导入 operatorspy 模块 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) from operatorspy import ( open_lib, @@ -23,11 +22,6 @@ class Inplace(Enum): OUT_OF_PLACE = auto() - # 对于 concat 算子,通常不支持 in-place 操作,因此这里只保留 OUT_OF_PLACE - # 你可以根据实际需求扩展其他选项 - # INPLACE_A = auto() - # INPLACE_B = auto() - class ConcatDescriptor(Structure): _fields_ = [("device", c_int32),] @@ -37,7 +31,6 @@ class ConcatDescriptor(Structure): def concat_py(*tensors, dim=0): - """使用 PyTorch 进行拼接的辅助函数""" return torch.cat(tensors, dim=dim) @@ -58,7 +51,6 @@ def test( f"Testing Concat on {torch_device} with output_shape:{c_shape}, input_shapes:{input_shapes}, axis:{axis}, dtype:{tensor_dtype}, inplace: {inplace.name}" ) - # 创建输入张量 inputs = [torch.rand(shape, dtype=tensor_dtype).to(torch_device) for shape in input_shapes] for idx, tensor in enumerate(inputs): @@ -66,64 +58,45 @@ def test( print(tensor) print("-" * 50) - # 创建输出张量 if inplace == Inplace.OUT_OF_PLACE: c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device) else: - # 对于 concat,通常不支持 in-place 操作,因此这里简化为 OUT_OF_PLACE c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device) - # 使用 PyTorch 进行拼接,作为参考答案 ans = concat_py(*inputs, dim=axis) - - print("ans:",ans) - print("-" * 50) - # 将张量转换为 infiniop 所需的格式 input_tensors = [to_tensor(t, lib) for t in inputs] c_tensor = to_tensor(c, lib) if inplace == Inplace.OUT_OF_PLACE else to_tensor(c, lib) - # 创建 Concat 描述符 descriptor = infiniopConcatDescriptor_t() - - # 准备输入描述符数组 + num_inputs = len(input_tensors) input_desc_array_type = infiniopTensorDescriptor_t * num_inputs input_desc_array = input_desc_array_type(*[t.descriptor for t in input_tensors]) - # 创建描述符 check_error( lib.infiniopCreateConcatDescriptor( handle, ctypes.byref(descriptor), - c_tensor.descriptor, # 使用 c_tensor 的描述符 - input_desc_array, # 输入张量描述符数组 + c_tensor.descriptor, + input_desc_array, c_uint64(num_inputs), - c_uint64(axis), + c_int64(axis), ) ) - print("c1:",c) - print("-" * 50) - - # 执行拼接操作 input_data_ptrs = (c_void_p * num_inputs)(*[t.data for t in input_tensors]) check_error( lib.infiniopConcat( descriptor, c_tensor.data, ctypes.cast(input_data_ptrs, POINTER(c_void_p)), - None # 假设不需要流 + None ) ) - - print("c2:",c) - print("-" * 50) - # 验证结果 - assert torch.allclose(c, ans, atol=0, rtol=1e-5), "Concat result does not match PyTorch's result." + assert torch.allclose(c, ans, atol=0, rtol=0), "Concat result does not match PyTorch's result." - # 销毁描述符 check_error(lib.infiniopDestroyConcatDescriptor(descriptor)) @@ -154,53 +127,85 @@ def test_bang(lib, test_cases): if __name__ == "__main__": - # 定义测试用例 + test_cases = [ - # (output_shape, axis, input_shapes, inplace) - - ((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), - # ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), - # ((3, 7), 1, [(3, 2), (3, 4), (3,1)], Inplace.OUT_OF_PLACE), - # ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), - # ((1, 1), 0, [(1, 1)], Inplace.OUT_OF_PLACE), - # ((4, 5, 6), 0, [(1, 5, 6), (3, 5, 6)], Inplace.OUT_OF_PLACE), - # ((2, 3, 6), 2, [(2, 3, 2), (2, 3, 4)], Inplace.OUT_OF_PLACE), - - # 添加更多测试用例以覆盖不同的维度和拼接轴 - # ((2, 10, 3), 1, [(2, 5, 3), (2, 2, 3),(2,3,3)], Inplace.OUT_OF_PLACE), # 拼接沿第二维 + + ((6,), 0, [(2,), (4,)], Inplace.OUT_OF_PLACE), + + ((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), + ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), + ((3, 7), 1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE), + ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), + + ((4, 3, 6), 0, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE), + ((2, 6, 3), 1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), + ((2, 3, 6), 2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), + + ((4, 3, 5, 6), 0, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE), + ((2, 5, 5, 6), 1, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 6), 2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 6), 3, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 15), 3, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE), + + ((4, 2, 3, 4, 5), 0, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), + ((2, 4, 3, 2, 5), 1, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 4, 4, 5), 2, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 3, 8, 5), 3, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 3, 4, 5), 4, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE), + ((4, 14, 3, 4, 5), 1, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE), + + + ((6,), -1, [(2,), (4,)], Inplace.OUT_OF_PLACE), + + ((6, 3), -2, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), + ((3, 6), -1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), + ((3, 7), -1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE), + ((3, 3, 10), -1, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), + + ((4, 3, 6), -3, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE), + ((2, 6, 3), -2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), + ((2, 3, 6), -1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), + + ((4, 3, 5, 6), -4, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE), + ((2, 5, 5, 6), -3, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 6), -2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 6), -1, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 15), -1, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE), + + ((4, 2, 3, 4, 5), -5, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), + ((2, 4, 3, 2, 5), -4, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 4, 4, 5), -3, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 3, 8, 5), -2, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 3, 4, 5), -1, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE), + ((4, 14, 3, 4, 5), -4, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE), ] args = get_args() lib = open_lib() - # 绑定 C++ 函数 - # 创建 Concat 描述符 lib.infiniopCreateConcatDescriptor.restype = c_int32 lib.infiniopCreateConcatDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopConcatDescriptor_t), - infiniopTensorDescriptor_t, # 输出张量描述符 - POINTER(infiniopTensorDescriptor_t), # 输入张量描述符数组 - c_uint64, # 输入张量数量 - c_uint64, # 拼接轴 + infiniopTensorDescriptor_t, + POINTER(infiniopTensorDescriptor_t), + c_uint64, # nums_input + c_int64, # axis ] - # 执行 Concat lib.infiniopConcat.restype = c_int32 lib.infiniopConcat.argtypes = [ infiniopConcatDescriptor_t, - c_void_p, # 输出数据指针 - POINTER(c_void_p), # 输入数据指针数组 - c_void_p, # 流(假设为 NULL) + c_void_p, + POINTER(c_void_p), + c_void_p, ] - # 销毁 Concat 描述符 lib.infiniopDestroyConcatDescriptor.restype = c_int32 lib.infiniopDestroyConcatDescriptor.argtypes = [ infiniopConcatDescriptor_t, ] - # 根据命令行参数执行测试 if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/src/ops/concat/cpu/concat_cpu.cc b/src/ops/concat/cpu/concat_cpu.cc index 3d7b8d95..0f7dce72 100644 --- a/src/ops/concat/cpu/concat_cpu.cc +++ b/src/ops/concat/cpu/concat_cpu.cc @@ -2,32 +2,31 @@ #include "../../../devices/cpu/common_cpu.h" #include "../../utils.h" - infiniopStatus_t cpuCreateConcatDescriptor( infiniopHandle_t handle, ConcatCpuDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y, infiniopTensorDescriptor_t *x, uint64_t num_inputs, - uint64_t axis) { + int64_t axis) { if (y == nullptr || x == nullptr || desc_ptr == nullptr || num_inputs == 0) { return STATUS_BAD_PARAM; } - uint64_t ndim = y->ndim; // 输出张量维度 - if (axis >= ndim) { - return STATUS_BAD_TENSOR_SHAPE; + int64_t ndim = y->ndim; + if (axis >= ndim || axis < -ndim) { + return STATUS_BAD_PARAM; } - uint64_t total_size = 0; // 拼接轴的总大小 - std::vector> input_shapes(num_inputs); // 输入张量形状 - std::vector> input_strides(num_inputs); // 输入张量步长 + if(axis < 0){ + axis = axis + ndim; + } + + uint64_t total_size = 0; + std::vector> input_shapes(num_inputs); - // 提取输出张量的形状和步长 std::vector output_shape(y->shape, y->shape + ndim); - std::vector output_stride(y->strides, y->strides + ndim); - // 验证输入张量的形状和步长,并记录形状信息 for (size_t i = 0; i < num_inputs; ++i) { if (x[i]->dt != y->dt) { @@ -37,246 +36,97 @@ infiniopStatus_t cpuCreateConcatDescriptor( if (x[i]->ndim != ndim) { return STATUS_BAD_TENSOR_SHAPE; } - + for (size_t j = 0; j < ndim; ++j) { if (j != axis && x[i]->shape[j] != y->shape[j]) { return STATUS_BAD_TENSOR_SHAPE; } } - // 记录每个输入张量的形状和步长 - input_shapes[i] = std::vector(x[i]->shape, x[i]->shape + ndim); - input_strides[i] = std::vector(x[i]->strides, x[i]->strides + ndim); - // 累加拼接轴的总大小 + input_shapes[i] = std::vector(x[i]->shape, x[i]->shape + ndim); total_size += x[i]->shape[axis]; } - // 验证输出张量形状是否匹配 if (total_size != y->shape[axis]) { return STATUS_BAD_TENSOR_SHAPE; } - // 初始化Concat描述符 *desc_ptr = new ConcatCpuDescriptor{ DevCpu, y->dt, axis, - ndim, num_inputs, input_shapes, - input_strides, output_shape, - output_stride }; return STATUS_SUCCESS; } - -// 销毁Concat描述符 infiniopStatus_t cpuDestroyConcatDescriptor(ConcatCpuDescriptor_t desc) { delete desc; return STATUS_SUCCESS; } - -// Helper function to handle different data types template infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc, T* y, void const** x) { - uint64_t axis = desc->axis; + int64_t axis = desc->axis; uint64_t num_inputs = desc->num_inputs; - const std::vector> &input_shapes = desc->input_shapes; - const std::vector> &input_strides = desc->input_strides; - const std::vector &output_shape = desc->output_shape; - const std::vector &output_stride = desc->output_stride; - uint64_t ndim = desc->ndim; - + const std::vector>& input_shapes = desc->input_shapes; + const std::vector& output_shape = desc->output_shape; - // 计算拼接轴之前的总元素数(外层维度) - uint64_t outer_dim = 1; - for (uint64_t d = 0; d < axis; ++d) { - outer_dim *= output_shape[d]; + size_t blockOffsetInner = 1; + for (size_t i = output_shape.size() - 1; i > axis; --i) { + blockOffsetInner *= output_shape[i]; } + size_t blockOffset = output_shape[axis] * blockOffsetInner; - // 计算拼接轴之后的总元素数(内层维度) - uint64_t inner_dim = 1; - for (uint64_t d = axis + 1; d < ndim; ++d) { - inner_dim *= output_shape[d]; - } - - // 计算每个输入张量在拼接轴上的偏移量 - std::vector dim_offsets(num_inputs, 0); - for (uint64_t i = 1; i < num_inputs; ++i) { - dim_offsets[i] = dim_offsets[i - 1] + input_shapes[i - 1][axis]; - } + for (size_t i = 0; i < num_inputs; ++i) { + const std::vector& input_shape = input_shapes[i]; - // 并行化外部循环 - #pragma omp parallel for - for (uint64_t od = 0; od < outer_dim; ++od) { - // 计算当前外层索引在各维度上的位置 - // 例如,如果外层维度为 [d0, d1, ..., d(axis-1)] - // 则 od 可以被分解为 d0 * (output_stride[0]) + d1 * (output_stride[1]) + ... - std::vector indices(ndim, 0); - uint64_t tmp = od; - for (uint64_t d = 0; d < axis; ++d) { - indices[d] = tmp / output_stride[d]; - tmp %= output_stride[d]; + size_t dimOffset = 0; + for (size_t j = 0; j < i; ++j) { + dimOffset += input_shapes[j][axis]; } - for (uint64_t i = 0; i < num_inputs; ++i) { - // 输入张量的拼接轴上的偏移 - uint64_t input_axis_offset = dim_offsets[i]; + size_t localBlockOffset = 1; + for (size_t j = input_shape.size() - 1; j >= axis && j != static_cast(-1); --j) { + localBlockOffset *= input_shape[j]; + } + + size_t innerOffset = blockOffsetInner * dimOffset; + size_t inSize = 1; + for (auto dim : input_shape) { + inSize *= dim; + } - // 遍历拼接轴上的所有元素 - for (uint64_t a = 0; a < input_shapes[i][axis]; ++a) { - // 设置当前拼接轴的索引 - indices[axis] = a + input_axis_offset; + T* input_data = static_cast(const_cast(x[i])); - // 计算输出张量的线性索引 - uint64_t y_offset = 0; - for (uint64_t d = 0; d < ndim; ++d) { - y_offset += indices[d] * output_stride[d]; - } + #pragma omp parallel for + for (size_t iOffset = 0; iOffset < inSize; ++iOffset) { - // 计算输入张量的线性索引 - uint64_t x_offset = 0; - for (uint64_t d = 0; d < ndim; ++d) { - x_offset += indices[d] * input_strides[i][d]; - } + size_t oOffset = iOffset % localBlockOffset + innerOffset + + iOffset / localBlockOffset * blockOffset; - // 复制数据 - y[y_offset] = reinterpret_cast(x[i])[x_offset]; - } + y[oOffset] = input_data[iOffset]; } } - return STATUS_SUCCESS; + return STATUS_SUCCESS; } -// 主拼接函数 infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc, void *y, void const **x, void *stream) { - // 根据数据类型调用相应的模板实例 + switch (desc->dtype.size) { case sizeof(float): // FLOAT32 return concatCompute(desc, reinterpret_cast(y), x); - // 可以根据需要添加更多数据类型 + // add other data.type default: return STATUS_SUCCESS; } } - - - - - -// infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc, -// void *y, -// void const **x, -// void *stream) { -// // 从描述符中获取必要信息 -// uint64_t axis = desc->axis; // 拼接轴 -// uint64_t num_inputs = desc->num_inputs; // 输入张量数量 -// const std::vector> &input_shapes = desc->input_shapes; // 输入张量形状 -// const std::vector> &input_strides = desc->input_strides; // 输入张量步长 -// const std::vector &output_shape = desc->output_shape; // 输出张量形状 -// const std::vector &output_stride = desc->output_stride; // 输出张量步长 - -// DT dtype = desc->dtype; -// size_t element_size = dtype.size; -// uint64_t ndim = desc->ndim; - -// // 初始化累计偏移量,用于拼接轴的起始位置 -// uint64_t cumulative_axis_offset = 0; - -// // 遍历每个输入张量 -// for (uint64_t tensor_idx = 0; tensor_idx < num_inputs; ++tensor_idx) { -// const uint8_t *x_data = static_cast(x[tensor_idx]); -// const auto &x_shape = input_shapes[tensor_idx]; -// const auto &x_stride = input_strides[tensor_idx]; - -// uint64_t axis_size = x_shape[axis]; // 当前张量在拼接轴上的大小 - -// // 计算非拼接轴的遍历总数(用于确定每个块的偏移量) -// uint64_t outer_loops = 1; -// for (uint64_t i = 0; i < ndim; ++i) { -// if (i != axis) { -// outer_loops *= x_shape[i]; -// } -// } - -// uint64_t x_size=1; -// for(int i=0; i < ndim; i++){ -// x_size = x_size * x_shape[i]; -// } -// x_size = x_size * element_size; - -// // 遍历非拼接轴的所有元素块 -// for (uint64_t outer_idx = 0; outer_idx < outer_loops; ++outer_idx) { -// // 将线性索引转换为多维索引 -// std::vector indices(ndim, 0); -// linearToMultiDim(indices, outer_idx, x_shape, axis); - -// // 计算输入和输出张量的偏移量 -// uint64_t input_block_offset =computeOffset(indices, x_stride) * element_size; -// indices[axis] += cumulative_axis_offset; -// uint64_t output_block_offset =computeOffset(indices, output_stride) * element_size; - -// // 计算剩余空间 -// // uint64_t remaining_space_in_x = x_size - input_block_offset; - -// uint64_t remaining_space_in_x = (x_stride[axis] - input_block_offset % x_stride[axis]) * element_size; - -// printf("remaining_space_in_x: %llu\n",remaining_space_in_x); - -// // 计算当前块的大小,但不能超过剩余空间 -// uint64_t block_size = std::min(axis_size * x_stride[axis] * element_size, remaining_space_in_x); - - -// memcpy(static_cast(y) + output_block_offset, x_data + input_block_offset, block_size); -// } - -// // 更新累计偏移量 -// cumulative_axis_offset += axis_size; -// } - -// return STATUS_SUCCESS; - -// } - - -void linearToMultiDim(std::vector &indices, - uint64_t linear_idx, - const std::vector &shape, - uint64_t exclude_axis) { - uint64_t ndim = shape.size(); - for (int64_t dim = ndim - 1; dim >= 0; --dim) { - if (dim == exclude_axis) - continue; // 跳过拼接轴 - indices[dim] = linear_idx % shape[dim]; - linear_idx /= shape[dim]; - } -} - - -uint64_t computeOffset(const std::vector &indices, - const std::vector &stride) { - uint64_t offset = 0; - for (uint64_t dim = 0; dim < indices.size(); ++dim) { - offset += indices[dim] * stride[dim]; - } - return offset; -} - - - -// uint64_t remaining_space_in_x = (x_stride[axis] - input_block_offset % x_stride[axis]) * element_size; - -// printf("remaining_space_in_x: %llu\n",remaining_space_in_x); - -// // 每次拷贝当前块 -// uint64_t block_size = axis_size * x_stride[axis] * element_size; \ No newline at end of file diff --git a/src/ops/concat/cpu/concat_cpu.h b/src/ops/concat/cpu/concat_cpu.h index 8431f198..a8d4d71d 100644 --- a/src/ops/concat/cpu/concat_cpu.h +++ b/src/ops/concat/cpu/concat_cpu.h @@ -4,46 +4,29 @@ #include #include -// 支持高维拼接的CPU-specific Concat描述符 struct ConcatCpuDescriptor { - Device device; // 设备类型(例如 DevCpu) - DT dtype; // 数据类型 - uint64_t axis; // 拼接轴(从0开始) - uint64_t ndim; // 张量维度 - uint64_t num_inputs; // 输入张量的数量 - std::vector> input_shapes; // 输入张量的形状 - std::vector> input_strides; // 输入张量的步长 - std::vector output_shape; // 输出张量的形状 - std::vector output_stride; // 输出张量的步长 + Device device; + DT dtype; + int64_t axis; + uint64_t num_inputs; + std::vector> input_shapes; + std::vector output_shape; }; - - typedef struct ConcatCpuDescriptor *ConcatCpuDescriptor_t; -// 创建Concat描述符 infiniopStatus_t cpuCreateConcatDescriptor(infiniopHandle_t handle, ConcatCpuDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y, infiniopTensorDescriptor_t *x, uint64_t num_inputs, - uint64_t axis); + int64_t axis); -// 执行Concat操作 infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc, void *y, void const **x, void *stream); -// 销毁Concat描述符 infiniopStatus_t cpuDestroyConcatDescriptor(ConcatCpuDescriptor_t desc); -void linearToMultiDim(std::vector &indices, - uint64_t linear_idx, - const std::vector &shape, - uint64_t exclude_axis); - -uint64_t computeOffset(const std::vector &indices, - const std::vector &stride); - #endif diff --git a/src/ops/concat/operator.cc b/src/ops/concat/operator.cc index e99d5de8..5f3cdae1 100644 --- a/src/ops/concat/operator.cc +++ b/src/ops/concat/operator.cc @@ -10,14 +10,13 @@ #include "cuda/concat.cuh" #endif -// 创建Concat描述符 __C infiniopStatus_t infiniopCreateConcatDescriptor( infiniopHandle_t handle, infiniopConcatDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y, infiniopTensorDescriptor_t *x, uint64_t num_inputs, - uint64_t axis) { + int64_t axis) { switch (handle->device) { #ifdef ENABLE_CPU case DevCpu: @@ -32,7 +31,7 @@ __C infiniopStatus_t infiniopCreateConcatDescriptor( return STATUS_BAD_DEVICE; } -// 执行Concat操作 + __C infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc, void *y, void const **x, void *stream) { switch (desc->device) { #ifdef ENABLE_CPU @@ -48,7 +47,7 @@ __C infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc, void *y, vo return STATUS_BAD_DEVICE; } -// 销毁Concat描述符 + __C infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc) { switch (desc->device) { #ifdef ENABLE_CPU From 2870b4e2edeed7df977010d2cd26317a87ac5b6c Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Tue, 24 Dec 2024 09:51:18 +0800 Subject: [PATCH 3/4] update2 --- operatorspy/tests/concat.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/operatorspy/tests/concat.py b/operatorspy/tests/concat.py index f5cccba4..fb3baad8 100644 --- a/operatorspy/tests/concat.py +++ b/operatorspy/tests/concat.py @@ -52,11 +52,6 @@ def test( ) inputs = [torch.rand(shape, dtype=tensor_dtype).to(torch_device) for shape in input_shapes] - - for idx, tensor in enumerate(inputs): - print(f"Input {idx}:") - print(tensor) - print("-" * 50) if inplace == Inplace.OUT_OF_PLACE: c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device) From b63439716f7a16cd7c7853187233bba62f08b0fe Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Thu, 2 Jan 2025 15:55:53 +0800 Subject: [PATCH 4/4] feat: add CUDA support to concat --- operatorspy/tests/concat.py | 17 +++---- src/ops/concat/cpu/concat_cpu.cc | 21 +++++--- src/ops/concat/cuda/concat.cc | 73 +++++++++++++++++++++++++++ src/ops/concat/cuda/concat.cu | 86 ++++++++++++++++++++++++++++++++ src/ops/concat/cuda/concat.cuh | 36 +++++++++++++ 5 files changed, 215 insertions(+), 18 deletions(-) create mode 100644 src/ops/concat/cuda/concat.cc create mode 100644 src/ops/concat/cuda/concat.cu create mode 100644 src/ops/concat/cuda/concat.cuh diff --git a/operatorspy/tests/concat.py b/operatorspy/tests/concat.py index fb3baad8..96f34088 100644 --- a/operatorspy/tests/concat.py +++ b/operatorspy/tests/concat.py @@ -99,7 +99,8 @@ def test_cpu(lib, test_cases): device = DeviceEnum.DEVICE_CPU handle = create_handle(lib, device) for c_shape, axis, input_shapes, inplace in test_cases: - test(lib, handle, "cpu", c_shape, axis, input_shapes, inplace=inplace) + test(lib, handle, "cpu", c_shape, axis, input_shapes, tensor_dtype = torch.float16, inplace = inplace) + test(lib, handle, "cpu", c_shape, axis, input_shapes, tensor_dtype = torch.float32, inplace = inplace) destroy_handle(lib, handle) @@ -107,10 +108,10 @@ def test_cuda(lib, test_cases): device = DeviceEnum.DEVICE_CUDA handle = create_handle(lib, device) for c_shape, axis, input_shapes, inplace in test_cases: - test(lib, handle, "cuda", c_shape, axis, input_shapes, inplace=inplace) + test(lib, handle, "cuda", c_shape, axis, input_shapes, tensor_dtype = torch.float16, inplace = inplace) + test(lib, handle, "cuda", c_shape, axis, input_shapes, tensor_dtype = torch.float32, inplace = inplace) destroy_handle(lib, handle) - def test_bang(lib, test_cases): import torch_mlu @@ -124,6 +125,7 @@ def test_bang(lib, test_cases): if __name__ == "__main__": test_cases = [ + #output_tensor, axis, inputs_tensors, inplace ((6,), 0, [(2,), (4,)], Inplace.OUT_OF_PLACE), @@ -131,17 +133,14 @@ def test_bang(lib, test_cases): ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), ((3, 7), 1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE), ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), - ((4, 3, 6), 0, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE), ((2, 6, 3), 1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), ((2, 3, 6), 2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), - ((4, 3, 5, 6), 0, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE), ((2, 5, 5, 6), 1, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE), ((2, 3, 5, 6), 2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE), ((2, 3, 5, 6), 3, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE), ((2, 3, 5, 15), 3, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE), - ((4, 2, 3, 4, 5), 0, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), ((2, 4, 3, 2, 5), 1, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE), ((1, 2, 4, 4, 5), 2, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE), @@ -149,30 +148,26 @@ def test_bang(lib, test_cases): ((1, 2, 3, 4, 5), 4, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE), ((4, 14, 3, 4, 5), 1, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE), - ((6,), -1, [(2,), (4,)], Inplace.OUT_OF_PLACE), - ((6, 3), -2, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), ((3, 6), -1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), ((3, 7), -1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE), ((3, 3, 10), -1, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), - ((4, 3, 6), -3, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE), ((2, 6, 3), -2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), ((2, 3, 6), -1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), - ((4, 3, 5, 6), -4, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE), ((2, 5, 5, 6), -3, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE), ((2, 3, 5, 6), -2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE), ((2, 3, 5, 6), -1, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE), ((2, 3, 5, 15), -1, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE), - ((4, 2, 3, 4, 5), -5, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), ((2, 4, 3, 2, 5), -4, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE), ((1, 2, 4, 4, 5), -3, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE), ((1, 2, 3, 8, 5), -2, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), ((1, 2, 3, 4, 5), -1, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE), ((4, 14, 3, 4, 5), -4, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE), + ] args = get_args() diff --git a/src/ops/concat/cpu/concat_cpu.cc b/src/ops/concat/cpu/concat_cpu.cc index 0f7dce72..6c9bd419 100644 --- a/src/ops/concat/cpu/concat_cpu.cc +++ b/src/ops/concat/cpu/concat_cpu.cc @@ -13,6 +13,10 @@ infiniopStatus_t cpuCreateConcatDescriptor( return STATUS_BAD_PARAM; } + if (!is_contiguous(y)) { + return STATUS_BAD_TENSOR_STRIDES; + } + int64_t ndim = y->ndim; if (axis >= ndim || axis < -ndim) { return STATUS_BAD_PARAM; @@ -29,6 +33,10 @@ infiniopStatus_t cpuCreateConcatDescriptor( for (size_t i = 0; i < num_inputs; ++i) { + if (!is_contiguous(x[i])) { + return STATUS_BAD_TENSOR_STRIDES; + } + if (x[i]->dt != y->dt) { return STATUS_BAD_TENSOR_DTYPE; } @@ -121,12 +129,11 @@ infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc, void *y, void const **x, void *stream) { - - switch (desc->dtype.size) { - case sizeof(float): // FLOAT32 - return concatCompute(desc, reinterpret_cast(y), x); - // add other data.type - default: - return STATUS_SUCCESS; + if (desc->dtype == F16) { + return concatCompute(desc, reinterpret_cast(y), x); + } + if (desc->dtype == F32) { + return concatCompute(desc, reinterpret_cast(y), x); } + return STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/ops/concat/cuda/concat.cc b/src/ops/concat/cuda/concat.cc new file mode 100644 index 00000000..d99d167b --- /dev/null +++ b/src/ops/concat/cuda/concat.cc @@ -0,0 +1,73 @@ +#include "concat.cuh" +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" + +infiniopStatus_t cudaCreateConcatDescriptor(CudaHandle_t handle, + ConcatCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *x, + uint64_t num_inputs, + int64_t axis){ + if (y == nullptr || x == nullptr || desc_ptr == nullptr || num_inputs == 0) { + return STATUS_BAD_PARAM; + } + + if (!is_contiguous(y)) { + return STATUS_BAD_TENSOR_STRIDES; + } + + int64_t ndim = y->ndim; + if (axis >= ndim || axis < -ndim) { + return STATUS_BAD_PARAM; + } + + if(axis < 0){ + axis = axis + ndim; + } + uint64_t total_size = 0; + + std::vector> input_shapes(num_inputs); + std::vector output_shape(y->shape, y->shape + ndim); + + for (size_t i = 0; i < num_inputs; ++i) { + + if (!is_contiguous(x[i])) { + return STATUS_BAD_TENSOR_STRIDES; + } + + if (x[i]->dt != y->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (x[i]->ndim != ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + for (size_t j = 0; j < ndim; ++j) { + if (j != axis && x[i]->shape[j] != y->shape[j]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + + input_shapes[i] = std::vector(x[i]->shape, x[i]->shape + ndim); + total_size += x[i]->shape[axis]; + } + + if (total_size != y->shape[axis]) { + return STATUS_BAD_TENSOR_SHAPE; + } + + *desc_ptr = new ConcatCudaDescriptor{ + DevNvGpu, + y->dt, + axis, + num_inputs, + input_shapes, + output_shape, + }; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroyConcatDescriptor(ConcatCudaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/concat/cuda/concat.cu b/src/ops/concat/cuda/concat.cu new file mode 100644 index 00000000..2c3d8ad6 --- /dev/null +++ b/src/ops/concat/cuda/concat.cu @@ -0,0 +1,86 @@ +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +#include "concat.cuh" + +// Kernel function to perform concatenation on GPU +template +__global__ void concatKernel(const T* x, T* y, + size_t inSize, + size_t localBlockOffset, + size_t innerOffset, + size_t blockOffset) { + size_t iOffset = blockIdx.x * blockDim.x + threadIdx.x; + if (iOffset < inSize) { + size_t oOffset = (iOffset % localBlockOffset) + innerOffset + + (iOffset / localBlockOffset) * blockOffset; + y[oOffset] = x[iOffset]; + } +} + +template +infiniopStatus_t concatCompute(ConcatCudaDescriptor_t& desc, + T* y, + void const** x, + cudaStream_t stream) { + int64_t axis = desc->axis; + uint64_t num_inputs = desc->num_inputs; + const std::vector>& input_shapes = desc->input_shapes; + const std::vector& output_shape = desc->output_shape; + + size_t blockOffsetInner = 1; + for (size_t i = output_shape.size() - 1; i > axis; --i) { + blockOffsetInner *= output_shape[i]; + } + size_t blockOffset = output_shape[axis] * blockOffsetInner; + +#pragma unroll + for (size_t i = 0; i < num_inputs; ++i) { + const std::vector& input_shape = input_shapes[i]; + + size_t dimOffset = 0; + for (size_t j = 0; j < i; ++j) { + dimOffset += input_shapes[j][axis]; + } + + size_t localBlockOffset = 1; + for (size_t j = input_shape.size() - 1; j >= axis && j != static_cast(-1); --j) { + localBlockOffset *= input_shape[j]; + } + + size_t innerOffset = blockOffsetInner * dimOffset; + size_t inSize = 1; + for (auto dim : input_shape) { + inSize *= dim; + } + + T* input_data = static_cast(const_cast(x[i])); + + // Launch CUDA kernel + int threads = 256; + int blocks = (inSize + threads - 1) / threads; + concatKernel<<>>(input_data, y, inSize, localBlockOffset, innerOffset, blockOffset); + + // Check for CUDA errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + return STATUS_EXECUTION_FAILED; + } + } + + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaConcat(ConcatCudaDescriptor_t desc, + void* y, + void const** x, + void* stream) { + cudaStream_t cudaStream = reinterpret_cast(stream); + + if (desc->dtype == F16) { + return concatCompute(desc, reinterpret_cast(y), x, cudaStream); + } + if (desc->dtype == F32) { + return concatCompute(desc, reinterpret_cast(y), x, cudaStream); + } + return STATUS_BAD_TENSOR_DTYPE; +} \ No newline at end of file diff --git a/src/ops/concat/cuda/concat.cuh b/src/ops/concat/cuda/concat.cuh new file mode 100644 index 00000000..9eeaf06f --- /dev/null +++ b/src/ops/concat/cuda/concat.cuh @@ -0,0 +1,36 @@ +#ifndef __CUDA_CONCAT_H__ +#define __CUDA_CONCAT_H__ + +#include "../../../devices/cuda/common_cuda.h" +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include +#include +#include + +struct ConcatCudaDescriptor { + Device device; + DT dtype; + int64_t axis; + uint64_t num_inputs; + std::vector> input_shapes; + std::vector output_shape; +}; + +typedef struct ConcatCudaDescriptor *ConcatCudaDescriptor_t; + +infiniopStatus_t cudaCreateConcatDescriptor(CudaHandle_t handle, + ConcatCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *x, + uint64_t nums_input, + int64_t axis); + +infiniopStatus_t cudaConcat(ConcatCudaDescriptor_t desc, + void *y, + void const **x, + void *stream); + +infiniopStatus_t cudaDestroyConcatDescriptor(ConcatCudaDescriptor_t desc); + +#endif \ No newline at end of file