diff --git a/folx/ad.py b/folx/ad.py index f1dc97d..6e06fc0 100644 --- a/folx/ad.py +++ b/folx/ad.py @@ -94,6 +94,8 @@ def jvp_fun(s): return jax.jvp(f, primals, unravel(s))[1] eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype) + if hasattr(jax.lax, 'pvary'): + eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma)) J = jax.vmap(jvp_fun, out_axes=-1)(eye) return J diff --git a/test/test_shard_map.py b/test/test_shard_map.py new file mode 100644 index 0000000..bd0b73f --- /dev/null +++ b/test/test_shard_map.py @@ -0,0 +1,29 @@ +from functools import partial + +import jax +import jax.numpy as jnp +import pytest +from packaging.version import Version + +from folx import forward_laplacian + + +@pytest.mark.skipif( + Version(jax.__version__) < Version('0.7.1'), reason='jax version too old' +) +def test_shard_map_bug_integer_pow(): + # see https://github.com/microsoft/folx/issues/38 + + def f(w, x): + return jax.lax.integer_pow(x @ w, 1) + + @jax.smap(out_axes=0, in_axes=(None, 0), axis_name='i') + @partial(jax.vmap, in_axes=(None, 0)) + def test(w, x): + return forward_laplacian(partial(f, w))(x) + + x = jnp.ones((1, 16)) + w = jnp.ones((16, 16)) + + with jax.set_mesh(jax.sharding.Mesh(jax.devices()[:1], 'i')): + test(w, x)