@@ -8,22 +8,25 @@ infiniopStatus_t cpuCreateConcatDescriptor(
88 infiniopTensorDescriptor_t y,
99 infiniopTensorDescriptor_t *x,
1010 uint64_t num_inputs,
11- uint64_t axis) {
11+ int64_t axis) {
1212 if (y == nullptr || x == nullptr || desc_ptr == nullptr || num_inputs == 0 ) {
1313 return STATUS_BAD_PARAM;
1414 }
1515
16- uint64_t ndim = y->ndim ; // 输出张量维度
17- if (axis >= ndim) {
18- return STATUS_BAD_TENSOR_SHAPE;
16+ int64_t ndim = y->ndim ;
17+ if (axis >= ndim || axis < -ndim) {
18+ return STATUS_BAD_PARAM;
19+ }
20+
21+ if (axis < 0 ){
22+ axis = axis + ndim;
1923 }
2024
21- uint64_t total_size = 0 ; // 拼接轴的总大小
22- std::vector<std::vector<uint64_t >> input_shapes (num_inputs); // 输入张量形状
25+ uint64_t total_size = 0 ;
26+ std::vector<std::vector<uint64_t >> input_shapes (num_inputs);
2327
2428 std::vector<uint64_t > output_shape (y->shape , y->shape + ndim);
2529
26- // 验证输入张量的形状和步长
2730 for (size_t i = 0 ; i < num_inputs; ++i) {
2831
2932 if (x[i]->dt != y->dt ) {
@@ -41,12 +44,9 @@ infiniopStatus_t cpuCreateConcatDescriptor(
4144 }
4245
4346 input_shapes[i] = std::vector<uint64_t >(x[i]->shape , x[i]->shape + ndim);
44-
45- // 累加拼接轴的总大小
4647 total_size += x[i]->shape [axis];
4748 }
4849
49- // 验证输出张量形状是否匹配
5050 if (total_size != y->shape [axis]) {
5151 return STATUS_BAD_TENSOR_SHAPE;
5252 }
@@ -72,8 +72,7 @@ template <typename T>
7272infiniopStatus_t concatCompute (const ConcatCpuDescriptor_t& desc,
7373 T* y,
7474 void const ** x) {
75- // 获取描述符中的信息
76- uint64_t axis = desc->axis ;
75+ int64_t axis = desc->axis ;
7776 uint64_t num_inputs = desc->num_inputs ;
7877 const std::vector<std::vector<uint64_t >>& input_shapes = desc->input_shapes ;
7978 const std::vector<uint64_t >& output_shape = desc->output_shape ;
@@ -84,7 +83,6 @@ infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc,
8483 }
8584 size_t blockOffset = output_shape[axis] * blockOffsetInner;
8685
87- // concat
8886 for (size_t i = 0 ; i < num_inputs; ++i) {
8987 const std::vector<uint64_t >& input_shape = input_shapes[i];
9088
@@ -104,7 +102,6 @@ infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc,
104102 inSize *= dim;
105103 }
106104
107- // 获取输入和输出的数据指针
108105 T* input_data = static_cast <T*>(const_cast <void *>(x[i]));
109106
110107 #pragma omp parallel for
@@ -120,16 +117,15 @@ infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc,
120117 return STATUS_SUCCESS;
121118}
122119
123- // 主拼接函数
124120infiniopStatus_t cpuConcat (ConcatCpuDescriptor_t desc,
125121 void *y,
126122 void const **x,
127123 void *stream) {
128- // 根据数据类型调用相应的模板实例
124+
129125 switch (desc->dtype .size ) {
130126 case sizeof (float ): // FLOAT32
131127 return concatCompute<float >(desc, reinterpret_cast <float *>(y), x);
132- // 可以根据需要添加更多数据类型
128+ // add other data.type
133129 default :
134130 return STATUS_SUCCESS;
135131 }
0 commit comments