Skip to content

Commit 2bb97dc

Browse files
committed
Fix pprint error, add description of attention configuration params
1 parent e082175 commit 2bb97dc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/schedulers/test_scheduler_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def test_full_loop_no_noise(self):
335335
result_mean = jnp.mean(jnp.abs(sample))
336336

337337
if jax_device == "tpu":
338-
assert abs(result_sum - 257.29) < 1.5e-2
338+
assert abs(result_sum - 263.11) < 1.5e-2
339339
assert abs(result_mean - 0.3349905) < 2e-5
340340
else:
341341
assert abs(result_sum - 255.1113) < 1e-2

0 commit comments

Comments
 (0)