Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions jaxonnxruntime/onnx_ops/scatterelements.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ def onnx_scatterelements(*input_args, axis, reduction):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ScatterElements for more details."""
data, indices, updates = input_args

idx = jnp.meshgrid(
*(jnp.arange(n) for n in data.shape), sparse=True, indexing="ij"
idx = list(
jnp.meshgrid(
*(jnp.arange(n) for n in data.shape), sparse=True, indexing="ij"
)
)
idx[axis] = indices
out = getattr(data.at[tuple(idx)], reduction)(
Expand Down
10 changes: 6 additions & 4 deletions jaxonnxruntime/onnx_ops/scatternd.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,12 @@ def onnx_scatternd(*input_args, reduction: str):
# e.g., for (r-k)=3 and z updates:
# [(1,1,1,1), (1,range(x1),1,1), (1,1,range(x2),1), (1,1,1,range(x3))]
# L----------L------------------L------------------L----> dim for z
idx = jnp.meshgrid(
*(jnp.arange(n) for n in [1] + list(data.shape[k:])),
sparse=True,
indexing="ij",
idx = list(
jnp.meshgrid(
*(jnp.arange(n) for n in [1] + list(data.shape[k:])),
sparse=True,
indexing="ij",
)
)
assert idx[0].ndim == (r - k) + 1

Expand Down
Loading