@@ -67,7 +67,8 @@ function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
6767 error (" both scale and bias must be provided or left as nothing" )
6868 end
6969 scale′, bias′ = _maybe_reshape (scale, affine_size), _maybe_reshape (bias, affine_size)
70- return _apply_scale_bias ((x .- μ) ./ sqrt .(σ² .+ ϵ), scale′, bias′)
70+ denom = inv .(sqrt .(σ² .+ ϵ))
71+ return _apply_scale_bias ((x .- μ) .* denom, scale′, bias′)
7172end
7273
7374"""
7677Contains running mean and variance estimates for stateful norm functions.
7778`momentum` controls the strength of the moving average update.
7879
79- If the parameters are mutable, they will be updated in-place.
80- Otherwise, they will be replaced wholesale.
80+ Parameters should be mutable and will be updated in-place.
8181
8282See also [`update_running_stats!`](@ref).
8383"""
84- mutable struct RunningStats{M <: AbstractArray , V <: AbstractArray , MT <: Real }
84+ struct RunningStats{M <: AbstractArray , V <: AbstractArray , MT <: Real }
8585 mean:: M
8686 variance:: V
8787 momentum:: MT
@@ -142,16 +142,9 @@ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Di
142142 correction = m / (m - one (V))
143143
144144 running_mean, running_var = stats. mean, stats. variance
145- if ChainRulesCore. is_inplaceable_destination (running_mean)
146- stats. mean .= res_mtm .* running_mean .+ momentum .* vec (μ)
147- else
148- stats. mean = res_mtm .* running_mean .+ momentum .* vec (μ)
149- end
150- if ChainRulesCore. is_inplaceable_destination (running_var)
151- stats. variance .= res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
152- else
153- stats. variance = res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
154- end
145+ stats. mean .= res_mtm .* running_mean .+ momentum .* vec (μ)
146+ stats. variance .= res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
147+ return
155148end
156149
157150# Convenience functions
@@ -190,7 +183,7 @@ function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias =
190183 throw (DimensionMismatch (" got $S reduction dims for $N -dimensional array" ))
191184 end
192185 μ, σ² = norm_stats (x, ntuple (identity, S))
193- return norm_helper (x, μ, σ², scale, bias, ϵ, size (x)[1 : S])
186+ return norm_helper (x, μ, σ², scale, bias, ϵ, size (x)[1 : S]:: Dims{S} )
194187end
195188
196189"""
0 commit comments