1- from ctypes import POINTER , Structure , c_int32 , c_void_p , c_uint64
1+ from ctypes import POINTER , Structure , c_int32 , c_void_p , c_uint64 , c_int64
22import ctypes
33import sys
44import os
2222
2323class Inplace (Enum ):
2424 OUT_OF_PLACE = auto ()
25- # 对于 concat 算子,通常不支持 in-place 操作,保留 OUT_OF_PLACE
2625
2726class ConcatDescriptor (Structure ):
2827 _fields_ = [("device" , c_int32 ),]
@@ -32,7 +31,6 @@ class ConcatDescriptor(Structure):
3231
3332
3433def concat_py (* tensors , dim = 0 ):
35- """使用 PyTorch 进行拼接的辅助函数"""
3634 return torch .cat (tensors , dim = dim )
3735
3836
@@ -63,7 +61,6 @@ def test(
6361 if inplace == Inplace .OUT_OF_PLACE :
6462 c = torch .zeros (c_shape , dtype = tensor_dtype ).to (torch_device )
6563 else :
66- # concat通常不支持 in-place 操作,简化为 OUT_OF_PLACE
6764 c = torch .zeros (c_shape , dtype = tensor_dtype ).to (torch_device )
6865
6966 ans = concat_py (* inputs , dim = axis )
@@ -74,10 +71,8 @@ def test(
7471 input_tensors = [to_tensor (t , lib ) for t in inputs ]
7572 c_tensor = to_tensor (c , lib ) if inplace == Inplace .OUT_OF_PLACE else to_tensor (c , lib )
7673
77- # 创建 Concat 描述符
7874 descriptor = infiniopConcatDescriptor_t ()
79-
80- # 准备输入描述符数组
75+
8176 num_inputs = len (input_tensors )
8277 input_desc_array_type = infiniopTensorDescriptor_t * num_inputs
8378 input_desc_array = input_desc_array_type (* [t .descriptor for t in input_tensors ])
@@ -89,7 +84,7 @@ def test(
8984 c_tensor .descriptor ,
9085 input_desc_array ,
9186 c_uint64 (num_inputs ),
92- c_uint64 (axis ),
87+ c_int64 (axis ),
9388 )
9489 )
9590
@@ -139,34 +134,57 @@ def test_bang(lib, test_cases):
139134
140135
141136if __name__ == "__main__" :
142- # 定义测试用例
137+
143138 test_cases = [
144- # (output_shape, axis, input_shapes, inplace)
145-
146- # 一维张量拼接
147- ((6 ,), 0 , [(2 ,), (4 ,)], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
148- # 二维张量拼接
149- ((6 , 3 ), 0 , [(2 , 3 ), (4 , 3 )], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
150- ((3 , 6 ), 1 , [(3 , 2 ), (3 , 4 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
151- ((3 , 7 ), 1 , [(3 , 2 ), (3 , 4 ), (3 , 1 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
152- ((3 , 3 , 10 ), 2 , [(3 , 3 , 4 ), (3 , 3 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第2轴拼接
153- # 三维张量拼接
154- ((4 , 3 , 6 ), 0 , [(3 , 3 , 6 ), (1 , 3 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
155- ((2 , 6 , 3 ), 1 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
156- ((2 , 3 , 6 ), 2 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ), # 沿第2轴拼接
157- # 四维张量拼接
158- ((4 , 3 , 5 , 6 ), 0 , [(1 , 3 , 5 , 6 ), (3 , 3 , 5 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
159- ((2 , 5 , 5 , 6 ), 1 , [(2 , 3 , 5 , 6 ), (2 , 2 , 5 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
160- ((2 , 3 , 5 , 6 ), 2 , [(2 , 3 , 2 , 6 ), (2 , 3 , 3 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第2轴拼接
161- ((2 , 3 , 5 , 6 ), 3 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 )], Inplace .OUT_OF_PLACE ), # 沿第3轴拼接
162- ((2 , 3 , 5 , 15 ), 3 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 ), (2 , 3 , 5 , 9 )], Inplace .OUT_OF_PLACE ), # 沿第3轴拼接
163- # 五维张量拼接
164- ((4 , 2 , 3 , 4 , 5 ), 0 , [(1 , 2 , 3 , 4 , 5 ), (3 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
165- ((2 , 4 , 3 , 2 , 5 ), 1 , [(2 , 2 , 3 , 2 , 5 ), (2 , 2 , 3 , 2 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
166- ((1 , 2 , 4 , 4 , 5 ), 2 , [(1 , 2 , 2 , 4 , 5 ), (1 , 2 , 2 , 4 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第2轴拼接
167- ((1 , 2 , 3 , 8 , 5 ), 3 , [(1 , 2 , 3 , 4 , 5 ), (1 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第3轴拼接
168- ((1 , 2 , 3 , 4 , 5 ), 4 , [(1 , 2 , 3 , 4 , 3 ), (1 , 2 , 3 , 4 , 2 )], Inplace .OUT_OF_PLACE ), # 沿第4轴拼接
169- ((4 , 14 , 3 , 4 , 5 ), 1 , [(4 , 3 , 3 , 4 , 5 ), (4 , 5 , 3 , 4 , 5 ), (4 , 6 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
139+
140+ ((6 ,), 0 , [(2 ,), (4 ,)], Inplace .OUT_OF_PLACE ),
141+
142+ ((6 , 3 ), 0 , [(2 , 3 ), (4 , 3 )], Inplace .OUT_OF_PLACE ),
143+ ((3 , 6 ), 1 , [(3 , 2 ), (3 , 4 )], Inplace .OUT_OF_PLACE ),
144+ ((3 , 7 ), 1 , [(3 , 2 ), (3 , 4 ), (3 , 1 )], Inplace .OUT_OF_PLACE ),
145+ ((3 , 3 , 10 ), 2 , [(3 , 3 , 4 ), (3 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
146+
147+ ((4 , 3 , 6 ), 0 , [(3 , 3 , 6 ), (1 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
148+ ((2 , 6 , 3 ), 1 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ),
149+ ((2 , 3 , 6 ), 2 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ),
150+
151+ ((4 , 3 , 5 , 6 ), 0 , [(1 , 3 , 5 , 6 ), (3 , 3 , 5 , 6 )], Inplace .OUT_OF_PLACE ),
152+ ((2 , 5 , 5 , 6 ), 1 , [(2 , 3 , 5 , 6 ), (2 , 2 , 5 , 6 )], Inplace .OUT_OF_PLACE ),
153+ ((2 , 3 , 5 , 6 ), 2 , [(2 , 3 , 2 , 6 ), (2 , 3 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
154+ ((2 , 3 , 5 , 6 ), 3 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 )], Inplace .OUT_OF_PLACE ),
155+ ((2 , 3 , 5 , 15 ), 3 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 ), (2 , 3 , 5 , 9 )], Inplace .OUT_OF_PLACE ),
156+
157+ ((4 , 2 , 3 , 4 , 5 ), 0 , [(1 , 2 , 3 , 4 , 5 ), (3 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
158+ ((2 , 4 , 3 , 2 , 5 ), 1 , [(2 , 2 , 3 , 2 , 5 ), (2 , 2 , 3 , 2 , 5 )], Inplace .OUT_OF_PLACE ),
159+ ((1 , 2 , 4 , 4 , 5 ), 2 , [(1 , 2 , 2 , 4 , 5 ), (1 , 2 , 2 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
160+ ((1 , 2 , 3 , 8 , 5 ), 3 , [(1 , 2 , 3 , 4 , 5 ), (1 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
161+ ((1 , 2 , 3 , 4 , 5 ), 4 , [(1 , 2 , 3 , 4 , 3 ), (1 , 2 , 3 , 4 , 2 )], Inplace .OUT_OF_PLACE ),
162+ ((4 , 14 , 3 , 4 , 5 ), 1 , [(4 , 3 , 3 , 4 , 5 ), (4 , 5 , 3 , 4 , 5 ), (4 , 6 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
163+
164+
165+ ((6 ,), - 1 , [(2 ,), (4 ,)], Inplace .OUT_OF_PLACE ),
166+
167+ ((6 , 3 ), - 2 , [(2 , 3 ), (4 , 3 )], Inplace .OUT_OF_PLACE ),
168+ ((3 , 6 ), - 1 , [(3 , 2 ), (3 , 4 )], Inplace .OUT_OF_PLACE ),
169+ ((3 , 7 ), - 1 , [(3 , 2 ), (3 , 4 ), (3 , 1 )], Inplace .OUT_OF_PLACE ),
170+ ((3 , 3 , 10 ), - 1 , [(3 , 3 , 4 ), (3 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
171+
172+ ((4 , 3 , 6 ), - 3 , [(3 , 3 , 6 ), (1 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
173+ ((2 , 6 , 3 ), - 2 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ),
174+ ((2 , 3 , 6 ), - 1 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ),
175+
176+ ((4 , 3 , 5 , 6 ), - 4 , [(1 , 3 , 5 , 6 ), (3 , 3 , 5 , 6 )], Inplace .OUT_OF_PLACE ),
177+ ((2 , 5 , 5 , 6 ), - 3 , [(2 , 3 , 5 , 6 ), (2 , 2 , 5 , 6 )], Inplace .OUT_OF_PLACE ),
178+ ((2 , 3 , 5 , 6 ), - 2 , [(2 , 3 , 2 , 6 ), (2 , 3 , 3 , 6 )], Inplace .OUT_OF_PLACE ),
179+ ((2 , 3 , 5 , 6 ), - 1 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 )], Inplace .OUT_OF_PLACE ),
180+ ((2 , 3 , 5 , 15 ), - 1 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 ), (2 , 3 , 5 , 9 )], Inplace .OUT_OF_PLACE ),
181+
182+ ((4 , 2 , 3 , 4 , 5 ), - 5 , [(1 , 2 , 3 , 4 , 5 ), (3 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
183+ ((2 , 4 , 3 , 2 , 5 ), - 4 , [(2 , 2 , 3 , 2 , 5 ), (2 , 2 , 3 , 2 , 5 )], Inplace .OUT_OF_PLACE ),
184+ ((1 , 2 , 4 , 4 , 5 ), - 3 , [(1 , 2 , 2 , 4 , 5 ), (1 , 2 , 2 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
185+ ((1 , 2 , 3 , 8 , 5 ), - 2 , [(1 , 2 , 3 , 4 , 5 ), (1 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
186+ ((1 , 2 , 3 , 4 , 5 ), - 1 , [(1 , 2 , 3 , 4 , 3 ), (1 , 2 , 3 , 4 , 2 )], Inplace .OUT_OF_PLACE ),
187+ ((4 , 14 , 3 , 4 , 5 ), - 4 , [(4 , 3 , 3 , 4 , 5 ), (4 , 5 , 3 , 4 , 5 ), (4 , 6 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ),
170188 ]
171189
172190 args = get_args ()
@@ -179,7 +197,7 @@ def test_bang(lib, test_cases):
179197 infiniopTensorDescriptor_t ,
180198 POINTER (infiniopTensorDescriptor_t ),
181199 c_uint64 , # nums_input
182- c_uint64 , # axis
200+ c_int64 , # axis
183201 ]
184202
185203 lib .infiniopConcat .restype = c_int32
0 commit comments