Skip to content

Commit f74eca1

Browse files
Merge pull request #136 from PanZezhong1725/random_sample_test
fix: random sample测试使用确定的分布
2 parents 9214a89 + f43df84 commit f74eca1

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

.github/workflows/main.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323

2424
- name: Install Python dependencies
2525
run: |
26+
pip install numpy
2627
pip install torch
2728
2829
- name: Install xmake

operatorspy/tests/random_sample.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
6363
else:
6464
end = topk
6565

66-
67-
6866
sum_s = 0
6967
for i in range(end):
7068
sum_s += dataNp[i]
@@ -78,12 +76,14 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
7876

7977
def random_sample_0(data):
8078
return torch.argmax(data)
79+
8180
def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16):
8281
print(
8382
f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}"
8483
)
85-
86-
data = torch.rand((voc), dtype=x_dtype).to(torch_device)
84+
data = torch.arange(voc).float() * 0.0001
85+
_perm = torch.randperm(voc)
86+
data = data[_perm].to(x_dtype).to(torch_device)
8787
if(topp > 0 and topk > 1):
8888
ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
8989
else:
@@ -130,12 +130,9 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
130130
if torch_device == "npu":
131131
torch.npu.synchronize()
132132

133-
assert indices[0].type(ans.dtype) == ans or abs(data[indices[0]] - data[ans]) == 0.0, "compute error"
134-
135-
136-
133+
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
137134
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
138-
135+
print("Test passed!")
139136

140137
def test_cpu(lib, test_cases):
141138
device = DeviceEnum.DEVICE_CPU
@@ -176,15 +173,16 @@ def test_ascend(lib, test_cases):
176173
if __name__ == "__main__":
177174
test_cases = [
178175
# voc, random_val, topp, topk, temperature
179-
(512, 0.92, 0.8, 3, 0.5),
180-
(4096, 0.95, 0.9, 5, 1.0),
181-
(16384, 0.85, 0.85, 10, 2.0),
182-
(512, 0.92, 0, 3, 0.5),
183-
(4096, 0.95, 0.9, 1, 1.0),
184-
(16384, 0.85, 0, 1, 2.0),
185-
(16384, 0.85, 0, 1, 2.0),
186-
(32000, 0.8, 0.8, 50, 1.0),
187-
(32000, 0.8, 1.0, 25, 1.0),
176+
(512, 0.8, 0.8, 3, 0.5),
177+
(4096, 0.05, 0.9, 5, 1.0),
178+
(16384, 0.15, 0.85, 10, 2.0),
179+
(512, 0.08, 0, 3, 0.5),
180+
(4096, 0.5, 0.9, 1, 1.0),
181+
(16384, 0.15, 0, 1, 2.0),
182+
(16384, 0.15, 0, 1, 2.0),
183+
(32000, 0.08, 0.8, 50, 1.0),
184+
(32000, 0.08, 1.0, 25, 1.0),
185+
# (119696, 0.01, 1.0, 100, 1.0),
188186
]
189187

190188
args = get_args()
@@ -228,4 +226,4 @@ def test_ascend(lib, test_cases):
228226
test_ascend(lib, test_cases)
229227
if not (args.cpu or args.cuda or args.bang or args.ascend):
230228
test_cpu(lib, test_cases)
231-
print("Test passed!")
229+
print("\033[92mTest passed!\033[0m")

0 commit comments

Comments
 (0)