From aa1be84f15ad8d018e613c799d3e03158acac3df Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 18 Dec 2025 10:21:58 -0800 Subject: [PATCH] Forward-fix for JAX API changes In https://github.com/jax-ml/jax/pull/33984, JAX will begin returning tuples rather than lists for several jax.numpy APIs. This fixes breakages associated with that change. PiperOrigin-RevId: 846323468 --- jaxonnxruntime/onnx_ops/scatterelements.py | 6 ++++-- jaxonnxruntime/onnx_ops/scatternd.py | 10 ++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/jaxonnxruntime/onnx_ops/scatterelements.py b/jaxonnxruntime/onnx_ops/scatterelements.py index 5538a3f..55eeb98 100644 --- a/jaxonnxruntime/onnx_ops/scatterelements.py +++ b/jaxonnxruntime/onnx_ops/scatterelements.py @@ -96,8 +96,10 @@ def onnx_scatterelements(*input_args, axis, reduction): """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ScatterElements for more details.""" data, indices, updates = input_args - idx = jnp.meshgrid( - *(jnp.arange(n) for n in data.shape), sparse=True, indexing="ij" + idx = list( + jnp.meshgrid( + *(jnp.arange(n) for n in data.shape), sparse=True, indexing="ij" + ) ) idx[axis] = indices out = getattr(data.at[tuple(idx)], reduction)( diff --git a/jaxonnxruntime/onnx_ops/scatternd.py b/jaxonnxruntime/onnx_ops/scatternd.py index 954e39f..c569dce 100644 --- a/jaxonnxruntime/onnx_ops/scatternd.py +++ b/jaxonnxruntime/onnx_ops/scatternd.py @@ -141,10 +141,12 @@ def onnx_scatternd(*input_args, reduction: str): # e.g., for (r-k)=3 and z updates: # [(1,1,1,1), (1,range(x1),1,1), (1,1,range(x2),1), (1,1,1,range(x3))] # L----------L------------------L------------------L----> dim for z - idx = jnp.meshgrid( - *(jnp.arange(n) for n in [1] + list(data.shape[k:])), - sparse=True, - indexing="ij", + idx = list( + jnp.meshgrid( + *(jnp.arange(n) for n in [1] + list(data.shape[k:])), + sparse=True, + indexing="ij", + ) ) assert idx[0].ndim == (r - k) + 1