From b05ba77f931a43053823fabc7a0c9c653b5c1a4d Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Sat, 22 Nov 2025 01:01:34 +0100 Subject: [PATCH 1/5] add missing pvary --- folx/ad.py | 1 + 1 file changed, 1 insertion(+) diff --git a/folx/ad.py b/folx/ad.py index f1dc97d..e9481a8 100644 --- a/folx/ad.py +++ b/folx/ad.py @@ -94,6 +94,7 @@ def jvp_fun(s): return jax.jvp(f, primals, unravel(s))[1] eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype) + eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma)) J = jax.vmap(jvp_fun, out_axes=-1)(eye) return J From 4ad5eab377d48fc0eccdbfaa78077f43a4c9d348 Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Sat, 22 Nov 2025 01:07:21 +0100 Subject: [PATCH 2/5] add regression test --- test/test_shard_map.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 test/test_shard_map.py diff --git a/test/test_shard_map.py b/test/test_shard_map.py new file mode 100644 index 0000000..f764523 --- /dev/null +++ b/test/test_shard_map.py @@ -0,0 +1,24 @@ +from functools import partial + +import jax +import jax.numpy as jnp + +from folx import forward_laplacian + + +def test_shard_map_bug_integer_pow(): + # see https://github.com/microsoft/folx/issues/38 + + def f(w, x): + return jax.lax.integer_pow(x @ w, 1) + + @jax.smap(out_axes=0, in_axes=(None, 0), axis_name='i') + @partial(jax.vmap, in_axes=(None, 0)) + def test(w, x): + return forward_laplacian(partial(f, w))(x) + + x = jnp.ones((1, 16)) + w = jnp.ones((16, 16)) + + with jax.set_mesh(jax.sharding.Mesh(jax.devices(), 'i')): + test(w, x) From cf13833dad802ce03b27424242e44048cd5c1aee Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Tue, 25 Nov 2025 14:24:37 +0100 Subject: [PATCH 3/5] support old jax --- folx/ad.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/folx/ad.py b/folx/ad.py index e9481a8..6e06fc0 100644 --- a/folx/ad.py +++ b/folx/ad.py @@ -94,7 +94,8 @@ def jvp_fun(s): return jax.jvp(f, primals, unravel(s))[1] eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype) - eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma)) + if hasattr(jax.lax, 'pvary'): + eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma)) J = jax.vmap(jvp_fun, out_axes=-1)(eye) return J From 773d61054de7d356aa0981ed179f7c1bed3787ec Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Tue, 25 Nov 2025 14:25:01 +0100 Subject: [PATCH 4/5] test on a single device only (its sufficient to trigger the error) --- test/test_shard_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_shard_map.py b/test/test_shard_map.py index f764523..ec9d6d4 100644 --- a/test/test_shard_map.py +++ b/test/test_shard_map.py @@ -20,5 +20,5 @@ def test(w, x): x = jnp.ones((1, 16)) w = jnp.ones((16, 16)) - with jax.set_mesh(jax.sharding.Mesh(jax.devices(), 'i')): + with jax.set_mesh(jax.sharding.Mesh(jax.devices()[:1], 'i')): test(w, x) From 372494b9846f4a52ef5f7ea6cd60ed1b57f47769 Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Tue, 25 Nov 2025 14:25:34 +0100 Subject: [PATCH 5/5] skip test on old jax --- test/test_shard_map.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_shard_map.py b/test/test_shard_map.py index ec9d6d4..bd0b73f 100644 --- a/test/test_shard_map.py +++ b/test/test_shard_map.py @@ -2,10 +2,15 @@ import jax import jax.numpy as jnp +import pytest +from packaging.version import Version from folx import forward_laplacian +@pytest.mark.skipif( + Version(jax.__version__) < Version('0.7.1'), reason='jax version too old' +) def test_shard_map_bug_integer_pow(): # see https://github.com/microsoft/folx/issues/38