33import sys
44import os
55
6- # 调整路径以导入 operatorspy 模块
76sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." , ".." )))
87from operatorspy import (
98 open_lib ,
2423class Inplace (Enum ):
2524 OUT_OF_PLACE = auto ()
2625 # 对于 concat 算子,通常不支持 in-place 操作,因此这里只保留 OUT_OF_PLACE
27- # 你可以根据实际需求扩展其他选项
28- # INPLACE_A = auto()
29- # INPLACE_B = auto()
30-
3126
3227class ConcatDescriptor (Structure ):
3328 _fields_ = [("device" , c_int32 ),]
@@ -58,28 +53,24 @@ def test(
5853 f"Testing Concat on { torch_device } with output_shape:{ c_shape } , input_shapes:{ input_shapes } , axis:{ axis } , dtype:{ tensor_dtype } , inplace: { inplace .name } "
5954 )
6055
61- # 创建输入张量
6256 inputs = [torch .rand (shape , dtype = tensor_dtype ).to (torch_device ) for shape in input_shapes ]
6357
6458 for idx , tensor in enumerate (inputs ):
6559 print (f"Input { idx } :" )
6660 print (tensor )
6761 print ("-" * 50 )
6862
69- # 创建输出张量
7063 if inplace == Inplace .OUT_OF_PLACE :
7164 c = torch .zeros (c_shape , dtype = tensor_dtype ).to (torch_device )
7265 else :
7366 # 对于 concat,通常不支持 in-place 操作,因此这里简化为 OUT_OF_PLACE
7467 c = torch .zeros (c_shape , dtype = tensor_dtype ).to (torch_device )
7568
76- # 使用 PyTorch 进行拼接,作为参考答案
7769 ans = concat_py (* inputs , dim = axis )
7870
7971 print ("ans:" ,ans )
8072 print ("-" * 50 )
8173
82- # 将张量转换为 infiniop 所需的格式
8374 input_tensors = [to_tensor (t , lib ) for t in inputs ]
8475 c_tensor = to_tensor (c , lib ) if inplace == Inplace .OUT_OF_PLACE else to_tensor (c , lib )
8576
@@ -91,22 +82,17 @@ def test(
9182 input_desc_array_type = infiniopTensorDescriptor_t * num_inputs
9283 input_desc_array = input_desc_array_type (* [t .descriptor for t in input_tensors ])
9384
94- # 创建描述符
9585 check_error (
9686 lib .infiniopCreateConcatDescriptor (
9787 handle ,
9888 ctypes .byref (descriptor ),
99- c_tensor .descriptor , # 使用 c_tensor 的描述符
100- input_desc_array , # 输入张量描述符数组
89+ c_tensor .descriptor ,
90+ input_desc_array ,
10191 c_uint64 (num_inputs ),
10292 c_uint64 (axis ),
10393 )
10494 )
10595
106- print ("c1:" ,c )
107- print ("-" * 50 )
108-
109- # 执行拼接操作
11096 input_data_ptrs = (c_void_p * num_inputs )(* [t .data for t in input_tensors ])
11197 check_error (
11298 lib .infiniopConcat (
@@ -121,9 +107,8 @@ def test(
121107 print ("-" * 50 )
122108
123109 # 验证结果
124- assert torch .allclose (c , ans , atol = 0 , rtol = 1e-5 ), "Concat result does not match PyTorch's result."
110+ assert torch .allclose (c , ans , atol = 0 , rtol = 0 ), "Concat result does not match PyTorch's result."
125111
126- # 销毁描述符
127112 check_error (lib .infiniopDestroyConcatDescriptor (descriptor ))
128113
129114
@@ -157,50 +142,59 @@ def test_bang(lib, test_cases):
157142 # 定义测试用例
158143 test_cases = [
159144 # (output_shape, axis, input_shapes, inplace)
160-
161- ((6 , 3 ), 0 , [(2 , 3 ), (4 , 3 )], Inplace .OUT_OF_PLACE ),
162- # ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
163- # ((3, 7), 1, [(3, 2), (3, 4), (3,1)], Inplace.OUT_OF_PLACE),
164- # ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
165- # ((1, 1), 0, [(1, 1)], Inplace.OUT_OF_PLACE),
166- # ((4, 5, 6), 0, [(1, 5, 6), (3, 5, 6)], Inplace.OUT_OF_PLACE),
167- # ((2, 3, 6), 2, [(2, 3, 2), (2, 3, 4)], Inplace.OUT_OF_PLACE),
168-
169- # 添加更多测试用例以覆盖不同的维度和拼接轴
170- # ((2, 10, 3), 1, [(2, 5, 3), (2, 2, 3),(2,3,3)], Inplace.OUT_OF_PLACE), # 拼接沿第二维
145+
146+ # 一维张量拼接
147+ ((6 ,), 0 , [(2 ,), (4 ,)], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
148+ # 二维张量拼接
149+ ((6 , 3 ), 0 , [(2 , 3 ), (4 , 3 )], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
150+ ((3 , 6 ), 1 , [(3 , 2 ), (3 , 4 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
151+ ((3 , 7 ), 1 , [(3 , 2 ), (3 , 4 ), (3 , 1 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
152+ ((3 , 3 , 10 ), 2 , [(3 , 3 , 4 ), (3 , 3 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第2轴拼接
153+ # 三维张量拼接
154+ ((4 , 3 , 6 ), 0 , [(3 , 3 , 6 ), (1 , 3 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
155+ ((2 , 6 , 3 ), 1 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
156+ ((2 , 3 , 6 ), 2 , [(2 , 3 , 3 ), (2 , 3 , 3 )], Inplace .OUT_OF_PLACE ), # 沿第2轴拼接
157+ # 四维张量拼接
158+ ((4 , 3 , 5 , 6 ), 0 , [(1 , 3 , 5 , 6 ), (3 , 3 , 5 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
159+ ((2 , 5 , 5 , 6 ), 1 , [(2 , 3 , 5 , 6 ), (2 , 2 , 5 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
160+ ((2 , 3 , 5 , 6 ), 2 , [(2 , 3 , 2 , 6 ), (2 , 3 , 3 , 6 )], Inplace .OUT_OF_PLACE ), # 沿第2轴拼接
161+ ((2 , 3 , 5 , 6 ), 3 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 )], Inplace .OUT_OF_PLACE ), # 沿第3轴拼接
162+ ((2 , 3 , 5 , 15 ), 3 , [(2 , 3 , 5 , 3 ), (2 , 3 , 5 , 3 ), (2 , 3 , 5 , 9 )], Inplace .OUT_OF_PLACE ), # 沿第3轴拼接
163+ # 五维张量拼接
164+ ((4 , 2 , 3 , 4 , 5 ), 0 , [(1 , 2 , 3 , 4 , 5 ), (3 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第0轴拼接
165+ ((2 , 4 , 3 , 2 , 5 ), 1 , [(2 , 2 , 3 , 2 , 5 ), (2 , 2 , 3 , 2 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
166+ ((1 , 2 , 4 , 4 , 5 ), 2 , [(1 , 2 , 2 , 4 , 5 ), (1 , 2 , 2 , 4 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第2轴拼接
167+ ((1 , 2 , 3 , 8 , 5 ), 3 , [(1 , 2 , 3 , 4 , 5 ), (1 , 2 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第3轴拼接
168+ ((1 , 2 , 3 , 4 , 5 ), 4 , [(1 , 2 , 3 , 4 , 3 ), (1 , 2 , 3 , 4 , 2 )], Inplace .OUT_OF_PLACE ), # 沿第4轴拼接
169+ ((4 , 14 , 3 , 4 , 5 ), 1 , [(4 , 3 , 3 , 4 , 5 ), (4 , 5 , 3 , 4 , 5 ), (4 , 6 , 3 , 4 , 5 )], Inplace .OUT_OF_PLACE ), # 沿第1轴拼接
171170 ]
172171
173172 args = get_args ()
174173 lib = open_lib ()
175174
176- # 绑定 C++ 函数
177- # 创建 Concat 描述符
178175 lib .infiniopCreateConcatDescriptor .restype = c_int32
179176 lib .infiniopCreateConcatDescriptor .argtypes = [
180177 infiniopHandle_t ,
181178 POINTER (infiniopConcatDescriptor_t ),
182- infiniopTensorDescriptor_t , # 输出张量描述符
183- POINTER (infiniopTensorDescriptor_t ), # 输入张量描述符数组
184- c_uint64 , # 输入张量数量
185- c_uint64 , # 拼接轴
179+ infiniopTensorDescriptor_t ,
180+ POINTER (infiniopTensorDescriptor_t ),
181+ c_uint64 , # nums_input
182+ c_uint64 , # axis
186183 ]
187184
188- # 执行 Concat
189185 lib .infiniopConcat .restype = c_int32
190186 lib .infiniopConcat .argtypes = [
191187 infiniopConcatDescriptor_t ,
192- c_void_p , # 输出数据指针
193- POINTER (c_void_p ), # 输入数据指针数组
194- c_void_p , # 流(假设为 NULL)
188+ c_void_p ,
189+ POINTER (c_void_p ),
190+ c_void_p ,
195191 ]
196192
197- # 销毁 Concat 描述符
198193 lib .infiniopDestroyConcatDescriptor .restype = c_int32
199194 lib .infiniopDestroyConcatDescriptor .argtypes = [
200195 infiniopConcatDescriptor_t ,
201196 ]
202197
203- # 根据命令行参数执行测试
204198 if args .cpu :
205199 test_cpu (lib , test_cases )
206200 if args .cuda :
0 commit comments