Open
Description
JAX just released version 0.8.2 (https://github.com/jax-ml/jax/releases/tag/jax-v0.8.2), which causes `is_jax_array` to return `False` for arrays traced under `jax.jit`. This is probably caused by the following change:
> JAX's `Tracer` no longer inherits from `jax.Array` at runtime. However, `jax.Array` now uses a custom metaclass such that `isinstance(x, Array)` is true if an object `x` represents a traced `Array`. Only some `Tracer`s represent `Array`s, so it is not correct for `Tracer` to inherit from `Array`.
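To illustrate why a check based on the class hierarchy can now disagree with `isinstance`, here is a self-contained pure-Python sketch. `_ArrayMeta`, `Array`, `ConcreteArray`, `Tracer`, `represents_array`, `is_array_by_type` and `is_array_by_instance` are hypothetical stand-ins, not JAX's or array_api_compat's actual code. The sketch assumes that `is_jax_array` inspects `type(x)` (for example via `issubclass`) rather than calling `isinstance(x, jax.Array)`.

```python
# Hypothetical stand-ins for jax.Array and a JAX tracer; not JAX's real code.


class _ArrayMeta(type):
    """Metaclass whose isinstance() check also accepts tracers that represent arrays."""

    def __instancecheck__(cls, obj):
        # Ordinary instances of Array subclasses still pass the normal check.
        if super().__instancecheck__(obj):
            return True
        # Extra rule: accept objects that declare they represent an array.
        return bool(getattr(obj, "represents_array", False))


class Array(metaclass=_ArrayMeta):
    pass


class ConcreteArray(Array):
    pass


class Tracer:  # deliberately does NOT inherit from Array
    def __init__(self, represents_array):
        self.represents_array = represents_array


def is_array_by_type(x):
    # Type-based check: only sees the class hierarchy (MRO) of type(x).
    return issubclass(type(x), Array)


def is_array_by_instance(x):
    # Instance-based check: goes through _ArrayMeta.__instancecheck__.
    return isinstance(x, Array)


concrete = ConcreteArray()
traced = Tracer(represents_array=True)
other = Tracer(represents_array=False)

print(is_array_by_type(concrete), is_array_by_instance(concrete))  # True True
print(is_array_by_type(traced), is_array_by_instance(traced))  # False True
print(is_array_by_type(other), is_array_by_instance(other))  # False False
```

Under that assumption, checking `isinstance(x, jax.Array)` directly would presumably recognise traced arrays again, since the custom metaclass hook only affects `isinstance`.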
Example
```python
import array_api_compat as xpc
import jax
import jax.numpy as jnp


@jax.jit
def fn(x):
    print(xpc.is_jax_array(x))  # False
    return jnp.zeros(x.shape, device=xpc.device(x))


x = jnp.array([1.0, 2.0, 3.0])
print(xpc.is_jax_array(x))  # True
y = fn(x)
```

which yields:
```
Traceback (most recent call last):
  File "/home/mschuck/repos/proto/.venv/lib/python3.14/site-packages/jax/_src/core.py", line 1071, in __getattr__
    attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'device'. Did you mean: 'to_device'?

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mschuck/repos/proto/proto/rotation_grad.py", line 28, in <module>
    y = fn(x)
  File "/home/mschuck/repos/proto/proto/rotation_grad.py", line 24, in fn
    return jnp.zeros(x.shape, device=xpc.device(x))
                                     ~~~~~~~~~~^^^
  File "/home/mschuck/repos/proto/.venv/lib/python3.14/site-packages/array_api_compat/common/_helpers.py", line 764, in device
    return x.device # pyright: ignore
           ^^^^^^^^
AttributeError: DynamicJaxprTracer has no attribute device. Did you mean: 'devices'?
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
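The traceback suggests a chain of failures: once `is_jax_array` returns `False`, `device()` no longer takes a JAX-specific path and instead reaches the generic `return x.device` (line 764 of `_helpers.py`), which a tracer does not support. The sketch below models that dispatch in plain Python. `Tracer`, `make_device`, and the JAX branch using `getattr(x, "device", None)` are hypothetical assumptions for illustration, not array_api_compat's source.

```python
class Tracer:
    """Hypothetical stand-in for a traced value: like DynamicJaxprTracer, it has no .device."""


def make_device(is_jax_array):
    """Build a device() helper around a given is_jax_array predicate (hypothetical model)."""

    def device(x):
        if is_jax_array(x):
            # Assumed JAX-specific branch: tolerate a missing .device under jit.
            return getattr(x, "device", None)
        # Generic fallback, mirroring the `return x.device` frame in the traceback.
        return x.device

    return device


# Modelled behaviour before 0.8.2: the tracer is recognised as a JAX array.
device_before = make_device(lambda x: True)
print(device_before(Tracer()))  # None

# Modelled behaviour after 0.8.2: the tracer is not recognised, so the generic path raises.
device_after = make_device(lambda x: False)
try:
    device_after(Tracer())
except AttributeError as err:
    print("AttributeError:", err)
```

In this model, fixing the predicate alone is enough to avoid the `AttributeError`, which is consistent with the traceback pointing at the generic fallback rather than at JAX itself.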