From 6a289c764c0c160c0be82df79d623fae33f34445 Mon Sep 17 00:00:00 2001 From: Chapaman Date: Tue, 16 Sep 2025 06:53:12 -0300 Subject: [PATCH 1/4] Fix reshape example in quickstart.livemd --- nx/guides/getting_started/quickstart.livemd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx/guides/getting_started/quickstart.livemd b/nx/guides/getting_started/quickstart.livemd index c47cf3de90..f772893515 100644 --- a/nx/guides/getting_started/quickstart.livemd +++ b/nx/guides/getting_started/quickstart.livemd @@ -162,7 +162,7 @@ Nx.shape(tensor) We can also create a new tensor with the given shape using `Nx.reshape/2`: ```elixir -Nx.reshape(tensor, {1, 4}, names: [:batches, :values]) +Nx.reshape(tensor, {1, 6}, names: [:batches, :values]) ``` This operation generally reuses all of the tensor data and simply From f15e7cd900292e58598085f8a960700bb9630918 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 27 Oct 2025 08:18:05 -0300 Subject: [PATCH 2/4] now we get the device from nx instead of just defaulting to the cpu --- torchx/lib/torchx/backend.ex | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 77308ce94d..95514c5d47 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1528,13 +1528,15 @@ defmodule Torchx.Backend do |> then(unfold_flat) |> then(function) + {device, _} = from_nx(tensor) + indices_to_flatten = tensor |> Nx.axes() |> Enum.map(fn axis -> tensor |> Nx.shape() - |> Nx.iota(axis: axis, backend: Torchx.Backend) + |> Nx.iota(axis: axis, backend: {Torchx.Backend, device: device}) |> then(unfold_flat) |> Nx.take_along_axis(Nx.new_axis(arg_idx, -1), axis: -1) end) From 8b8b6a6b29f6fa241ba43350afe0eaa01ea9327e Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 3 Nov 2025 08:17:46 -0300 Subject: [PATCH 3/4] added test for indices_to_flatten --- torchx/test/torchx/device_test.exs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchx/test/torchx/device_test.exs b/torchx/test/torchx/device_test.exs index 0f55a5366a..1db1451796 100644 --- a/torchx/test/torchx/device_test.exs +++ b/torchx/test/torchx/device_test.exs @@ -45,4 +45,12 @@ defmodule Torchx.DeviceTest do # assert_raise ArgumentError, fn -> Nx.backend_transfer(t) end end end + + describe "indices_to_flatten" do + test "works" do + t = Nx.tensor([[1, 2], [3, 4]], backend: {TB, device: @device}) + t2 = Nx.tensor([[2, 6], [3, 1]], backend: {TB, device: @device}) + assert Nx.window_scatter_max(t, t2, 0, {2, 3}, backend: {TB, device: @device}) == Nx.tensor([[0, 0, 0, 0, 6, 0], [0, 0, 2, 0, 0, 0], [0, 0, 3, 0, 0, 0], [0, 0, 0, 0, 0, 1]], backend: {TB, device: @device}) + end + end end From b96581d8a54dcbe11c131ffb13084eff7521e24c Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 3 Nov 2025 08:23:45 -0300 Subject: [PATCH 4/4] fixed test by using assert_equal --- torchx/test/torchx/device_test.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchx/test/torchx/device_test.exs b/torchx/test/torchx/device_test.exs index 1db1451796..32e717ee22 100644 --- a/torchx/test/torchx/device_test.exs +++ b/torchx/test/torchx/device_test.exs @@ -50,7 +50,7 @@ defmodule Torchx.DeviceTest do test "works" do t = Nx.tensor([[1, 2], [3, 4]], backend: {TB, device: @device}) t2 = Nx.tensor([[2, 6], [3, 1]], backend: {TB, device: @device}) - assert Nx.window_scatter_max(t, t2, 0, {2, 3}, backend: {TB, device: @device}) == Nx.tensor([[0, 0, 0, 0, 6, 0], [0, 0, 2, 0, 0, 0], [0, 0, 3, 0, 0, 0], [0, 0, 0, 0, 0, 1]], backend: {TB, device: @device}) + assert_equal Nx.window_scatter_max(t, t2, 0, {2, 3}), Nx.tensor([[0, 0, 0, 0, 6, 0], [0, 0, 2, 0, 0, 0], [0, 0, 3, 0, 0, 0], [0, 0, 0, 0, 0, 1]], backend: {TB, device: @device}) end end end