diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 3b43fdb78..437f57992 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -1632,7 +1632,7 @@ defmodule Axon.Loop do final_metrics_map = loop_state.metrics loop_state = %{loop_state | metrics: zero_metrics} - {status, final_metrics_map, state} = + {status, final_metrics_map, %State{} = state} = case fire_event(:started, handler_fns, loop_state, debug?) do {:halt_epoch, state} -> {:halted, final_metrics_map, state} @@ -1691,7 +1691,7 @@ defmodule Axon.Loop do {:halt_loop, state} -> {:halt, {final_metrics_map, state}} - {:continue, state} -> + {:continue, %State{} = state} -> {:cont, {batch_fn, Map.put(final_metrics_map, epoch, state.metrics), %State{ @@ -1922,9 +1922,9 @@ defmodule Axon.Loop do end # Halts an epoch during looping - defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do + defp halt_epoch(handler_fns, batch_fn, final_metrics_map, %State{} = loop_state, debug?) do case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do - {:halt_epoch, %{epoch: epoch, metrics: metrics} = state} -> + {:halt_epoch, %State{epoch: epoch, metrics: metrics} = state} -> final_metrics_map = Map.put(final_metrics_map, epoch, metrics) {:cont, {batch_fn, final_metrics_map, %State{state | epoch: epoch + 1, iteration: 0}}} diff --git a/lib/axon/quantization/layers.ex b/lib/axon/quantization/layers.ex index 80900a17e..00a42d978 100644 --- a/lib/axon/quantization/layers.ex +++ b/lib/axon/quantization/layers.ex @@ -46,12 +46,11 @@ defmodule Axon.Quantization.Layers do deftransformp reshape_scales(scales, y) do ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1)) - Nx.reshape(scales, Tuple.append(ones, :auto)) + Nx.reshape(scales, :erlang.append_element(ones, :auto)) end deftransformp reshape_output(output, x_shape) do all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1) - new_shape = Tuple.append(all_but_last, :auto) - Nx.reshape(output, new_shape) + Nx.reshape(output, :erlang.append_element(all_but_last, :auto)) end end diff --git a/mix.exs b/mix.exs index 52b943579..9e6006a1a 100644 --- a/mix.exs +++ b/mix.exs @@ -14,11 +14,14 @@ defmodule Axon.MixProject do deps: deps(), docs: docs(), description: "Create and train neural networks in Elixir", - package: package(), - preferred_cli_env: [ - docs: :docs, - "hex.publish": :docs - ] + package: package() + ] + end + + def cli do + [ + docs: :docs, + "hex.publish": :docs ] end