|
12 | 12 | from pytensor.compile.mode import get_default_mode, get_mode |
13 | 13 | from pytensor.compile.ops import DeepCopyOp, deep_copy_op |
14 | 14 | from pytensor.configdefaults import config |
15 | | -from pytensor.graph.basic import equal_computations, vars_between |
| 15 | +from pytensor.graph.basic import equal_computations |
16 | 16 | from pytensor.graph.fg import FunctionGraph |
17 | 17 | from pytensor.graph.rewriting.basic import check_stack_trace, out2in |
18 | 18 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery |
|
26 | 26 | ScalarFromTensor, |
27 | 27 | Split, |
28 | 28 | TensorFromScalar, |
| 29 | + cast, |
29 | 30 | join, |
30 | 31 | tile, |
31 | 32 | ) |
32 | 33 | from pytensor.tensor.elemwise import DimShuffle, Elemwise |
33 | 34 | from pytensor.tensor.math import ( |
34 | | - Sum, |
35 | 35 | add, |
36 | 36 | bitwise_and, |
37 | 37 | bitwise_or, |
@@ -1298,41 +1298,48 @@ def test_local_join_make_vector(): |
1298 | 1298 |
|
1299 | 1299 |
|
1300 | 1300 | def test_local_sum_make_vector(): |
| 1301 | + # To check that rewrite is applied, we must enforce dtype to |
| 1302 | + # allow rewrite to occur even if floatX != "float64" |
1301 | 1303 | a, b, c = scalars("abc") |
1302 | 1304 | mv = MakeVector(config.floatX) |
1303 | | - output = mv(a, b, c).sum() |
1304 | | - |
1305 | | - output = rewrite_graph(output) |
1306 | | - between = vars_between([a, b, c], [output]) |
1307 | | - for var in between: |
1308 | | - assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector)) |
| 1305 | + output = mv(a, b, c).sum(dtype="float64") |
| 1306 | + rewrite_output = rewrite_graph(output) |
| 1307 | + expected_output = cast( |
| 1308 | + add(*[cast(value, "float64") for value in [a, b, c]]), dtype="float64" |
| 1309 | + ) |
| 1310 | + assert equal_computations([expected_output], [rewrite_output]) |
1309 | 1311 |
|
1310 | | - # Check for empty sum |
| 1312 | + # Empty axes should return input vector since no sum is applied |
1311 | 1313 | a, b, c = scalars("abc") |
1312 | 1314 | mv = MakeVector(config.floatX) |
1313 | 1315 | output = mv(a, b, c).sum(axis=[]) |
| 1316 | + rewrite_output = rewrite_graph(output) |
| 1317 | + expected_output = mv(a, b, c) |
| 1318 | + assert equal_computations([expected_output], [rewrite_output]) |
1314 | 1319 |
|
1315 | | - output = rewrite_graph(output) |
1316 | | - between = vars_between([a, b, c], [output]) |
1317 | | - for var in between: |
1318 | | - assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
1319 | | - |
1320 | | - # Check empty MakeVector |
| 1320 | + # Empty input should return 0 |
1321 | 1321 | mv = MakeVector(config.floatX) |
1322 | 1322 | output = mv().sum() |
| 1323 | + rewrite_output = rewrite_graph(output) |
| 1324 | + expected_output = pt.as_tensor(0, dtype=config.floatX) |
| 1325 | + assert equal_computations([expected_output], [rewrite_output]) |
1323 | 1326 |
|
1324 | | - output = rewrite_graph(output) |
1325 | | - between = vars_between([a, b, c], [output]) |
1326 | | - for var in between: |
1327 | | - assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
1328 | | - |
| 1327 | + # Single element input should return element value |
| 1328 | + a = scalars("a") |
1329 | 1329 | mv = MakeVector(config.floatX) |
1330 | 1330 | output = mv(a).sum() |
1331 | | - |
1332 | | - output = rewrite_graph(output) |
1333 | | - between = vars_between([a, b, c], [output]) |
1334 | | - for var in between: |
1335 | | - assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
| 1331 | + rewrite_output = rewrite_graph(output) |
| 1332 | + expected_output = cast(a, config.floatX) |
| 1333 | + assert equal_computations([expected_output], [rewrite_output]) |
| 1334 | + |
| 1335 | + # This is a regression test for #653. Ensure that rewrite is NOT |
| 1336 | + # applied when user requests float32 |
| 1337 | + with config.change_flags(floatX="float32", warn_float64="raise"): |
| 1338 | + a, b, c = scalars("abc") |
| 1339 | + mv = MakeVector(config.floatX) |
| 1340 | + output = mv(a, b, c).sum() |
| 1341 | + rewrite_output = rewrite_graph(output) |
| 1342 | + assert equal_computations([output], [rewrite_output]) |
1336 | 1343 |
|
1337 | 1344 |
|
1338 | 1345 | @pytest.mark.parametrize( |
|
0 commit comments