jax==0.8.2 breaks is_jax_array and subsequently device #368

@amacati

Description

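JAX just released version 0.8.2 (https://github.com/jax-ml/jax/releases/tag/jax-v0.8.2). With it, is_jax_array returns False for traced arrays inside jax.jit. As a consequence, device(x) falls through to x.device, which the tracer does not have. This is probably caused by the following change from the release notes: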

jax's Tracer no longer inherits from jax.Array at runtime. However, jax.Array now uses a custom metaclass such isinstance(x, Array) is true if an object x represents a traced Array. Only some Tracers represent Arrays, so it is not correct for Tracer to inherit from Array.
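The distinction in that changelog entry can be illustrated with a toy model in plain Python. The names ArrayMeta, Array, ConcreteArray, Tracer and the represents_array flag are hypothetical stand-ins, not JAX's real implementation. The sketch only shows why a metaclass-based isinstance check can succeed while a check that inspects the class hierarchy (issubclass(type(x), Array) or the MRO) fails. It assumes, without having verified it against the array_api_compat source, that is_jax_array relies on a check of the latter kind.

```python
class ArrayMeta(type):
    # Toy metaclass (hypothetical, not JAX's code): isinstance() first does
    # the normal class-hierarchy check, then asks the object itself.
    def __instancecheck__(cls, obj):
        if issubclass(type(obj), cls):
            return True
        return bool(getattr(obj, "represents_array", False))


class Array(metaclass=ArrayMeta):
    pass


class ConcreteArray(Array):
    # Stand-in for a real, materialized array type.
    pass


class Tracer:
    # Stand-in for a tracer: deliberately NOT a subclass of Array.
    def __init__(self, represents_array):
        self.represents_array = represents_array


t = Tracer(represents_array=True)
other = Tracer(represents_array=False)

print(isinstance(t, Array))                # True: metaclass hook
print(issubclass(type(t), Array))          # False: hierarchy check
print(Array in type(t).__mro__)            # False: MRO check
print(isinstance(other, Array))            # False: tracer of a non-array
print(isinstance(ConcreteArray(), Array))  # True: real subclass
```

In this model, a helper built on issubclass or an MRO scan would report False for the tracer, which would match the False printed inside the jitted function in the example, while a plain isinstance check keeps working.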

Example

import array_api_compat as xpc
import jax
import jax.numpy as jnp


@jax.jit
def fn(x):
    print(xpc.is_jax_array(x))  # False
    return jnp.zeros(x.shape, device=xpc.device(x))


x = jnp.array([1.0, 2.0, 3.0])
print(xpc.is_jax_array(x))  # True
y = fn(x)

which yields

Traceback (most recent call last):
  File "/home/mschuck/repos/proto/.venv/lib/python3.14/site-packages/jax/_src/core.py", line 1071, in __getattr__
    attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'device'. Did you mean: 'to_device'?

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mschuck/repos/proto/proto/rotation_grad.py", line 28, in <module>
    y = fn(x)
  File "/home/mschuck/repos/proto/proto/rotation_grad.py", line 24, in fn
    return jnp.zeros(x.shape, device=xpc.device(x))
                                     ~~~~~~~~~~^^^
  File "/home/mschuck/repos/proto/.venv/lib/python3.14/site-packages/array_api_compat/common/_helpers.py", line 764, in device
    return x.device  # pyright: ignore
           ^^^^^^^^
AttributeError: DynamicJaxprTracer has no attribute device. Did you mean: 'devices'?
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
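Until this is fixed in array_api_compat, a caller-side workaround could look like the sketch below. is_jax_array_sketch and device_or_none are hypothetical helpers, not part of array_api_compat. They assume that isinstance(x, jax.Array) is True for traced arrays, as the changelog entry above states. When a traced array has no device attribute they return None, which leaves placement of the new array to JAX rather than guaranteeing it matches the input's device.

```python
import jax
import jax.numpy as jnp


def is_jax_array_sketch(x):
    # Hypothetical helper: a plain isinstance() check. Per the JAX 0.8.2
    # changelog this is True for tracers that represent arrays; before that
    # release Tracer inherited from jax.Array, so it was True there too.
    return isinstance(x, jax.Array)


def device_or_none(x):
    # Hypothetical helper: concrete arrays expose .device, traced arrays
    # may not, so fall back to None instead of raising AttributeError.
    if not is_jax_array_sketch(x):
        raise TypeError(f"expected a JAX array, got {type(x).__name__}")
    return getattr(x, "device", None)


@jax.jit
def fn(x):
    # device=None leaves placement of the new array to JAX.
    return jnp.zeros(x.shape, device=device_or_none(x))


y = fn(jnp.array([1.0, 2.0, 3.0]))
print(y)
```

Using getattr with a default keeps the helper working both for concrete arrays and for tracers whose attribute lookup raises AttributeError, as in the traceback above.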

Relevant Issues

jax-ml/jax#26000
scipy/scipy#22246
