Skip to content

Commit f00e101

Browse files
committed
add support for negative axis in concat opreation
1 parent cd39041 commit f00e101

File tree

1 file changed

+54
-36
lines changed

1 file changed

+54
-36
lines changed

operatorspy/tests/concat.py

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
1+
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_int64
22
import ctypes
33
import sys
44
import os
@@ -22,7 +22,6 @@
2222

2323
class Inplace(Enum):
2424
OUT_OF_PLACE = auto()
25-
# 对于 concat 算子,通常不支持 in-place 操作,保留 OUT_OF_PLACE
2625

2726
class ConcatDescriptor(Structure):
2827
_fields_ = [("device", c_int32),]
@@ -32,7 +31,6 @@ class ConcatDescriptor(Structure):
3231

3332

3433
def concat_py(*tensors, dim=0):
35-
"""使用 PyTorch 进行拼接的辅助函数"""
3634
return torch.cat(tensors, dim=dim)
3735

3836

@@ -63,7 +61,6 @@ def test(
6361
if inplace == Inplace.OUT_OF_PLACE:
6462
c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device)
6563
else:
66-
# concat通常不支持 in-place 操作,简化为 OUT_OF_PLACE
6764
c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device)
6865

6966
ans = concat_py(*inputs, dim=axis)
@@ -74,10 +71,8 @@ def test(
7471
input_tensors = [to_tensor(t, lib) for t in inputs]
7572
c_tensor = to_tensor(c, lib) if inplace == Inplace.OUT_OF_PLACE else to_tensor(c, lib)
7673

77-
# 创建 Concat 描述符
7874
descriptor = infiniopConcatDescriptor_t()
79-
80-
# 准备输入描述符数组
75+
8176
num_inputs = len(input_tensors)
8277
input_desc_array_type = infiniopTensorDescriptor_t * num_inputs
8378
input_desc_array = input_desc_array_type(*[t.descriptor for t in input_tensors])
@@ -89,7 +84,7 @@ def test(
8984
c_tensor.descriptor,
9085
input_desc_array,
9186
c_uint64(num_inputs),
92-
c_uint64(axis),
87+
c_int64(axis),
9388
)
9489
)
9590

@@ -139,34 +134,57 @@ def test_bang(lib, test_cases):
139134

140135

141136
if __name__ == "__main__":
142-
# 定义测试用例
137+
143138
test_cases = [
144-
# (output_shape, axis, input_shapes, inplace)
145-
146-
# 一维张量拼接
147-
((6,), 0, [(2,), (4,)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
148-
# 二维张量拼接
149-
((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
150-
((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
151-
((3, 7), 1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
152-
((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), # 沿第2轴拼接
153-
# 三维张量拼接
154-
((4, 3, 6), 0, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
155-
((2, 6, 3), 1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
156-
((2, 3, 6), 2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), # 沿第2轴拼接
157-
# 四维张量拼接
158-
((4, 3, 5, 6), 0, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
159-
((2, 5, 5, 6), 1, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
160-
((2, 3, 5, 6), 2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE), # 沿第2轴拼接
161-
((2, 3, 5, 6), 3, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE), # 沿第3轴拼接
162-
((2, 3, 5, 15), 3, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE), # 沿第3轴拼接
163-
# 五维张量拼接
164-
((4, 2, 3, 4, 5), 0, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
165-
((2, 4, 3, 2, 5), 1, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
166-
((1, 2, 4, 4, 5), 2, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE), # 沿第2轴拼接
167-
((1, 2, 3, 8, 5), 3, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), # 沿第3轴拼接
168-
((1, 2, 3, 4, 5), 4, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE), # 沿第4轴拼接
169-
((4, 14, 3, 4, 5), 1, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
139+
140+
((6,), 0, [(2,), (4,)], Inplace.OUT_OF_PLACE),
141+
142+
((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE),
143+
((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
144+
((3, 7), 1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE),
145+
((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
146+
147+
((4, 3, 6), 0, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE),
148+
((2, 6, 3), 1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
149+
((2, 3, 6), 2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
150+
151+
((4, 3, 5, 6), 0, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE),
152+
((2, 5, 5, 6), 1, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE),
153+
((2, 3, 5, 6), 2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE),
154+
((2, 3, 5, 6), 3, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE),
155+
((2, 3, 5, 15), 3, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE),
156+
157+
((4, 2, 3, 4, 5), 0, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
158+
((2, 4, 3, 2, 5), 1, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE),
159+
((1, 2, 4, 4, 5), 2, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE),
160+
((1, 2, 3, 8, 5), 3, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
161+
((1, 2, 3, 4, 5), 4, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE),
162+
((4, 14, 3, 4, 5), 1, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE),
163+
164+
165+
((6,), -1, [(2,), (4,)], Inplace.OUT_OF_PLACE),
166+
167+
((6, 3), -2, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE),
168+
((3, 6), -1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
169+
((3, 7), -1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE),
170+
((3, 3, 10), -1, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
171+
172+
((4, 3, 6), -3, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE),
173+
((2, 6, 3), -2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
174+
((2, 3, 6), -1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
175+
176+
((4, 3, 5, 6), -4, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE),
177+
((2, 5, 5, 6), -3, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE),
178+
((2, 3, 5, 6), -2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE),
179+
((2, 3, 5, 6), -1, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE),
180+
((2, 3, 5, 15), -1, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE),
181+
182+
((4, 2, 3, 4, 5), -5, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
183+
((2, 4, 3, 2, 5), -4, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE),
184+
((1, 2, 4, 4, 5), -3, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE),
185+
((1, 2, 3, 8, 5), -2, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
186+
((1, 2, 3, 4, 5), -1, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE),
187+
((4, 14, 3, 4, 5), -4, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE),
170188
]
171189

172190
args = get_args()
@@ -179,7 +197,7 @@ def test_bang(lib, test_cases):
179197
infiniopTensorDescriptor_t,
180198
POINTER(infiniopTensorDescriptor_t),
181199
c_uint64, # nums_input
182-
c_uint64, # axis
200+
c_int64, # axis
183201
]
184202

185203
lib.infiniopConcat.restype = c_int32

0 commit comments

Comments
 (0)