diff --git a/nx/guides/getting_started/quickstart.livemd b/nx/guides/getting_started/quickstart.livemd index c47cf3de90..f772893515 100644 --- a/nx/guides/getting_started/quickstart.livemd +++ b/nx/guides/getting_started/quickstart.livemd @@ -162,7 +162,7 @@ Nx.shape(tensor) We can also create a new tensor with the given shape using `Nx.reshape/2`: ```elixir -Nx.reshape(tensor, {1, 4}, names: [:batches, :values]) +Nx.reshape(tensor, {1, 6}, names: [:batches, :values]) ``` This operation generally reuses all of the tensor data and simply diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 77308ce94d..95514c5d47 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1528,13 +1528,15 @@ defmodule Torchx.Backend do |> then(unfold_flat) |> then(function) + {device, _} = from_nx(tensor) + indices_to_flatten = tensor |> Nx.axes() |> Enum.map(fn axis -> tensor |> Nx.shape() - |> Nx.iota(axis: axis, backend: Torchx.Backend) + |> Nx.iota(axis: axis, backend: {Torchx.Backend, device: device}) |> then(unfold_flat) |> Nx.take_along_axis(Nx.new_axis(arg_idx, -1), axis: -1) end)