Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,7 @@ defmodule Axon.Loop do
final_metrics_map = loop_state.metrics
loop_state = %{loop_state | metrics: zero_metrics}

{status, final_metrics_map, state} =
{status, final_metrics_map, %State{} = state} =
case fire_event(:started, handler_fns, loop_state, debug?) do
{:halt_epoch, state} ->
{:halted, final_metrics_map, state}
Expand Down Expand Up @@ -1691,7 +1691,7 @@ defmodule Axon.Loop do
{:halt_loop, state} ->
{:halt, {final_metrics_map, state}}

{:continue, state} ->
{:continue, %State{} = state} ->
{:cont,
{batch_fn, Map.put(final_metrics_map, epoch, state.metrics),
%State{
Expand Down Expand Up @@ -1922,9 +1922,9 @@ defmodule Axon.Loop do
end

# Halts an epoch during looping
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, %State{} = loop_state, debug?) do
case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do
{:halt_epoch, %{epoch: epoch, metrics: metrics} = state} ->
{:halt_epoch, %State{epoch: epoch, metrics: metrics} = state} ->
final_metrics_map = Map.put(final_metrics_map, epoch, metrics)
{:cont, {batch_fn, final_metrics_map, %State{state | epoch: epoch + 1, iteration: 0}}}

Expand Down
5 changes: 2 additions & 3 deletions lib/axon/quantization/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ defmodule Axon.Quantization.Layers do

deftransformp reshape_scales(scales, y) do
ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1))
Nx.reshape(scales, Tuple.append(ones, :auto))
Nx.reshape(scales, :erlang.append_element(ones, :auto))
end

deftransformp reshape_output(output, x_shape) do
all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1)
new_shape = Tuple.append(all_but_last, :auto)
Nx.reshape(output, new_shape)
Nx.reshape(output, :erlang.append_element(all_but_last, :auto))
end
end
13 changes: 8 additions & 5 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ defmodule Axon.MixProject do
deps: deps(),
docs: docs(),
description: "Create and train neural networks in Elixir",
package: package(),
preferred_cli_env: [
docs: :docs,
"hex.publish": :docs
]
package: package()
]
end

def cli do
[
docs: :docs,
"hex.publish": :docs
]
end
Comment on lines +17 to 26
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I this change correct for < 1.19?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure honestly, isn't this local-only thing and just a pure function otherwise? I thought it won't break even if not supported

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We suport this for a few versions already and yes, it won't break users.


Expand Down
Loading