Skip to content

Commit 0bdde99

Browse files
committed
update
1 parent 9d25304 commit 0bdde99

File tree

5 files changed

+71
-240
lines changed

5 files changed

+71
-240
lines changed

operatorspy/liboperators.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
Device = c_int
99
Optype = c_int
1010

11-
LIB_OPERATORS_DIR = os.path.join(os.environ.get("INFINI_ROOT"), "lib")
12-
11+
LIB_OPERATORS_DIR = os.path.join(os.environ.get("INFINI_ROOT"))
1312

1413
class TensorDescriptor(Structure):
1514
_fields_ = [
@@ -19,10 +18,8 @@ class TensorDescriptor(Structure):
1918
("pattern", POINTER(c_int64)),
2019
]
2120

22-
2321
infiniopTensorDescriptor_t = ctypes.POINTER(TensorDescriptor)
2422

25-
2623
class CTensor:
2724
def __init__(self, desc, data):
2825
self.descriptor = desc

operatorspy/tests/concat.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import sys
44
import os
55

6-
# 调整路径以导入 operatorspy 模块
76
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
87
from operatorspy import (
98
open_lib,
@@ -24,10 +23,6 @@
2423
class Inplace(Enum):
2524
OUT_OF_PLACE = auto()
2625
# 对于 concat 算子,通常不支持 in-place 操作,因此这里只保留 OUT_OF_PLACE
27-
# 你可以根据实际需求扩展其他选项
28-
# INPLACE_A = auto()
29-
# INPLACE_B = auto()
30-
3126

3227
class ConcatDescriptor(Structure):
3328
_fields_ = [("device", c_int32),]
@@ -58,28 +53,24 @@ def test(
5853
f"Testing Concat on {torch_device} with output_shape:{c_shape}, input_shapes:{input_shapes}, axis:{axis}, dtype:{tensor_dtype}, inplace: {inplace.name}"
5954
)
6055

61-
# 创建输入张量
6256
inputs = [torch.rand(shape, dtype=tensor_dtype).to(torch_device) for shape in input_shapes]
6357

6458
for idx, tensor in enumerate(inputs):
6559
print(f"Input {idx}:")
6660
print(tensor)
6761
print("-" * 50)
6862

69-
# 创建输出张量
7063
if inplace == Inplace.OUT_OF_PLACE:
7164
c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device)
7265
else:
7366
# 对于 concat,通常不支持 in-place 操作,因此这里简化为 OUT_OF_PLACE
7467
c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device)
7568

76-
# 使用 PyTorch 进行拼接,作为参考答案
7769
ans = concat_py(*inputs, dim=axis)
7870

7971
print("ans:",ans)
8072
print("-" * 50)
8173

82-
# 将张量转换为 infiniop 所需的格式
8374
input_tensors = [to_tensor(t, lib) for t in inputs]
8475
c_tensor = to_tensor(c, lib) if inplace == Inplace.OUT_OF_PLACE else to_tensor(c, lib)
8576

@@ -91,22 +82,17 @@ def test(
9182
input_desc_array_type = infiniopTensorDescriptor_t * num_inputs
9283
input_desc_array = input_desc_array_type(*[t.descriptor for t in input_tensors])
9384

94-
# 创建描述符
9585
check_error(
9686
lib.infiniopCreateConcatDescriptor(
9787
handle,
9888
ctypes.byref(descriptor),
99-
c_tensor.descriptor, # 使用 c_tensor 的描述符
100-
input_desc_array, # 输入张量描述符数组
89+
c_tensor.descriptor,
90+
input_desc_array,
10191
c_uint64(num_inputs),
10292
c_uint64(axis),
10393
)
10494
)
10595

106-
print("c1:",c)
107-
print("-" * 50)
108-
109-
# 执行拼接操作
11096
input_data_ptrs = (c_void_p * num_inputs)(*[t.data for t in input_tensors])
11197
check_error(
11298
lib.infiniopConcat(
@@ -121,9 +107,8 @@ def test(
121107
print("-" * 50)
122108

123109
# 验证结果
124-
assert torch.allclose(c, ans, atol=0, rtol=1e-5), "Concat result does not match PyTorch's result."
110+
assert torch.allclose(c, ans, atol=0, rtol=0), "Concat result does not match PyTorch's result."
125111

126-
# 销毁描述符
127112
check_error(lib.infiniopDestroyConcatDescriptor(descriptor))
128113

129114

@@ -157,50 +142,59 @@ def test_bang(lib, test_cases):
157142
# 定义测试用例
158143
test_cases = [
159144
# (output_shape, axis, input_shapes, inplace)
160-
161-
((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE),
162-
# ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
163-
# ((3, 7), 1, [(3, 2), (3, 4), (3,1)], Inplace.OUT_OF_PLACE),
164-
# ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
165-
# ((1, 1), 0, [(1, 1)], Inplace.OUT_OF_PLACE),
166-
# ((4, 5, 6), 0, [(1, 5, 6), (3, 5, 6)], Inplace.OUT_OF_PLACE),
167-
# ((2, 3, 6), 2, [(2, 3, 2), (2, 3, 4)], Inplace.OUT_OF_PLACE),
168-
169-
# 添加更多测试用例以覆盖不同的维度和拼接轴
170-
# ((2, 10, 3), 1, [(2, 5, 3), (2, 2, 3),(2,3,3)], Inplace.OUT_OF_PLACE), # 拼接沿第二维
145+
146+
# 一维张量拼接
147+
((6,), 0, [(2,), (4,)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
148+
# 二维张量拼接
149+
((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
150+
((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
151+
((3, 7), 1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
152+
((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), # 沿第2轴拼接
153+
# 三维张量拼接
154+
((4, 3, 6), 0, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
155+
((2, 6, 3), 1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
156+
((2, 3, 6), 2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), # 沿第2轴拼接
157+
# 四维张量拼接
158+
((4, 3, 5, 6), 0, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
159+
((2, 5, 5, 6), 1, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
160+
((2, 3, 5, 6), 2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE), # 沿第2轴拼接
161+
((2, 3, 5, 6), 3, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE), # 沿第3轴拼接
162+
((2, 3, 5, 15), 3, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE), # 沿第3轴拼接
163+
# 五维张量拼接
164+
((4, 2, 3, 4, 5), 0, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), # 沿第0轴拼接
165+
((2, 4, 3, 2, 5), 1, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
166+
((1, 2, 4, 4, 5), 2, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE), # 沿第2轴拼接
167+
((1, 2, 3, 8, 5), 3, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), # 沿第3轴拼接
168+
((1, 2, 3, 4, 5), 4, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE), # 沿第4轴拼接
169+
((4, 14, 3, 4, 5), 1, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE), # 沿第1轴拼接
171170
]
172171

173172
args = get_args()
174173
lib = open_lib()
175174

176-
# 绑定 C++ 函数
177-
# 创建 Concat 描述符
178175
lib.infiniopCreateConcatDescriptor.restype = c_int32
179176
lib.infiniopCreateConcatDescriptor.argtypes = [
180177
infiniopHandle_t,
181178
POINTER(infiniopConcatDescriptor_t),
182-
infiniopTensorDescriptor_t, # 输出张量描述符
183-
POINTER(infiniopTensorDescriptor_t), # 输入张量描述符数组
184-
c_uint64, # 输入张量数量
185-
c_uint64, # 拼接轴
179+
infiniopTensorDescriptor_t,
180+
POINTER(infiniopTensorDescriptor_t),
181+
c_uint64, # nums_input
182+
c_uint64, # axis
186183
]
187184

188-
# 执行 Concat
189185
lib.infiniopConcat.restype = c_int32
190186
lib.infiniopConcat.argtypes = [
191187
infiniopConcatDescriptor_t,
192-
c_void_p, # 输出数据指针
193-
POINTER(c_void_p), # 输入数据指针数组
194-
c_void_p, # 流(假设为 NULL)
188+
c_void_p,
189+
POINTER(c_void_p),
190+
c_void_p,
195191
]
196192

197-
# 销毁 Concat 描述符
198193
lib.infiniopDestroyConcatDescriptor.restype = c_int32
199194
lib.infiniopDestroyConcatDescriptor.argtypes = [
200195
infiniopConcatDescriptor_t,
201196
]
202197

203-
# 根据命令行参数执行测试
204198
if args.cpu:
205199
test_cpu(lib, test_cases)
206200
if args.cuda:

0 commit comments

Comments
 (0)