diff --git a/jaxonnxruntime/onnx_ops/scatterelements.py b/jaxonnxruntime/onnx_ops/scatterelements.py index 5538a3f..55eeb98 100644 --- a/jaxonnxruntime/onnx_ops/scatterelements.py +++ b/jaxonnxruntime/onnx_ops/scatterelements.py @@ -96,8 +96,10 @@ def onnx_scatterelements(*input_args, axis, reduction): """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ScatterElements for more details.""" data, indices, updates = input_args - idx = jnp.meshgrid( - *(jnp.arange(n) for n in data.shape), sparse=True, indexing="ij" + idx = list( + jnp.meshgrid( + *(jnp.arange(n) for n in data.shape), sparse=True, indexing="ij" + ) ) idx[axis] = indices out = getattr(data.at[tuple(idx)], reduction)( diff --git a/jaxonnxruntime/onnx_ops/scatternd.py b/jaxonnxruntime/onnx_ops/scatternd.py index 954e39f..c569dce 100644 --- a/jaxonnxruntime/onnx_ops/scatternd.py +++ b/jaxonnxruntime/onnx_ops/scatternd.py @@ -141,10 +141,12 @@ def onnx_scatternd(*input_args, reduction: str): # e.g., for (r-k)=3 and z updates: # [(1,1,1,1), (1,range(x1),1,1), (1,1,range(x2),1), (1,1,1,range(x3))] # L----------L------------------L------------------L----> dim for z - idx = jnp.meshgrid( - *(jnp.arange(n) for n in [1] + list(data.shape[k:])), - sparse=True, - indexing="ij", + idx = list( + jnp.meshgrid( + *(jnp.arange(n) for n in [1] + list(data.shape[k:])), + sparse=True, + indexing="ij", + ) ) assert idx[0].ndim == (r - k) + 1