diff --git a/src/fft.jl b/src/fft.jl index 13fd457..26910bc 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -8,49 +8,82 @@ const ComplexFloats = Complex{T} where T<:AbstractFloat # The following implements Bluestein's algorithm, following http://www.dsprelated.com/dspbooks/mdft/Bluestein_s_FFT_Algorithm.html # To add more types, add them in the union of the function's signature. -function generic_fft(x::AbstractVector{T}, region::Integer) where T<:AbstractFloats - @assert region == 1 - generic_fft(x) +function generic_fft!(x::AbstractVector{Complex{T}}) where {T<:AbstractFloat} + if ispow2(length(x)) + return generic_fft_pow2!(x) + end + return copyto!(x, generic_fft(x)) end -function generic_fft!(x::AbstractVector{T}, region::Integer=1) where T<:AbstractFloats +generic_fft!(x::AbstractVector) = copyto!(x, generic_fft(x)) + +function generic_fft!(x::AbstractVector{Complex{T}}, region::Integer) where {T<:AbstractFloat} @assert region == 1 - copyto!(x, generic_fft(x)) + generic_fft!(x) end -function generic_fft(x::AbstractVector{T}, region) where {T<:AbstractFloats} - @assert all(==(1), region) - generic_fft(x) +function _generic_fft_first_dim!(x, Ipost) + Threads.@threads for I in Ipost + generic_fft!(@view x[:, I]) + end + x end -function generic_fft!(x::AbstractVector{T}, region) where {T<:AbstractFloats} - @assert all(==(1), region) - copyto!(x, generic_fft(x)) -end +function generic_fft!(x, region::Integer) + @assert 1 <= region <= ndims(x) -function generic_fft!(x::AbstractMatrix{T}, region::Integer) where T<:AbstractFloats - if region == 1 - for j in 1:size(x, 2) - x[:, j] .= generic_fft(@view x[:, j]) - end - else - for k in 1:size(x, 1) - x[k, :] .= generic_fft(@view x[k, :]) + perm = ntuple(ndims(x)) do i + if i == 1 + region + elseif i == region + 1 + else + i end end + + y = permutedims(x, perm) + + Rright = CartesianIndices(size(y)[2:end]) + y = _generic_fft_first_dim!(y, Rright) + + permutedims!(x, y, perm) x end -function generic_fft!(x::AbstractMatrix{T}, region) where T<:AbstractFloats +function generic_fft!(x, region) for r in region generic_fft!(x, r) end x end -generic_fft(x::AbstractMatrix{T}, region::Integer) where T<:AbstractFloats = generic_fft!(copy(x), region) +function generic_fft!(x) + y = similar(x) + z = x + sz = size(x) + perm = ((2:ndims(x))..., 1) -generic_fft(x::AbstractMatrix{T}, region=ntuple(identity, ndims(x))) where T<:AbstractFloats = generic_fft!(copy(x), region) + for r in 1:ndims(x) + Rright = CartesianIndices(size(z)[2:end]) + _generic_fft_first_dim!(z, Rright) + + sz = (sz[2:end]..., sz[1]) + y = reshape(y, sz) + permutedims!(y, z, perm) + z, y = y, z + end + + if isodd(ndims(x)) + x .= z + end + x +end + + +generic_fft(x, region) = generic_fft!(copy(x), region) + +generic_fft(x) = generic_fft!(copy(x)) function generic_fft(x::AbstractVector{T}) where T<:AbstractFloats n = length(x) @@ -97,15 +130,14 @@ end # c_radix2.c in the GNU Scientific Library and four1 in the Numerical Recipes in C. # However, the trigonometric recurrence is improved for greater efficiency. # The algorithm starts with bit-reversal, then divides and conquers in-place. -function generic_fft_pow2!(x::AbstractVector{T}) where T<:AbstractFloat - n,big2=length(x),2one(T) +function generic_fft_pow2!(x::AbstractVector{Complex{T}}) where T<:AbstractFloat + n,big2=2length(x),2one(T) nn,j=n÷2,1 - for i=1:2:n-1 + for i=1:nn if j>i x[j], x[i] = x[i], x[j] - x[j+1], x[i+1] = x[i+1], x[j+1] end - m = nn + m = nn÷2 while m ≥ 2 && j > m j -= m m = m÷2 @@ -115,35 +147,34 @@ function generic_fft_pow2!(x::AbstractVector{T}) where T<:AbstractFloat logn = 2 while logn < n θ=-big2/logn - wtemp = sinpi(θ/2) - wpr, wpi = -2wtemp^2, sinpi(θ) - wr, wi = one(T), zero(T) - for m=1:2:logn-1 - for i=m:2logn:n - j=i+logn - mixr, mixi = wr*x[j]-wi*x[j+1], wr*x[j+1]+wi*x[j] - x[j], x[j+1] = x[i]-mixr, x[i+1]-mixi - x[i], x[i+1] = x[i]+mixr, x[i+1]+mixi + wp = complex(-2sinpi(θ/2)^2, sinpi(θ)) + w = complex(one(T)) + lognn = logn ÷ 2 + for m=1:lognn + for i=m:logn:nn + j=i+lognn + mix = w * x[j] + x[j] = x[i] - mix + x[i] = x[i] + mix end - wr = (wtemp=wr)*wpr-wi*wpi+wr - wi = wi*wpr+wtemp*wpi+wi + w = w * (1 + wp) end - logn = logn << 1 + logn = 2logn end return x end function generic_fft_pow2(x::AbstractVector{Complex{T}}) where T<:AbstractFloat - y = interlace_complex(x) - generic_fft_pow2!(y) - return deinterlace_complex(y) + return generic_fft_pow2!(copy(x)) end generic_fft_pow2(x::AbstractVector{T}) where T<:AbstractFloat = generic_fft_pow2(complex(x)) function generic_ifft_pow2(x::AbstractVector{Complex{T}}) where T<:AbstractFloat - y = interlace_complex(x, -) + y = conj.(x) # always create copy (conj(x) doesn't copy when eltype(x) is real) generic_fft_pow2!(y) - return ldiv!(T(length(x)), deinterlace_complex(y, -)) + N = T(length(x)) + @. y = conj(y) / N + return y end function generic_dct(x::StridedVector{T}, region::Integer) where T<:AbstractFloats diff --git a/test/fft_tests.jl b/test/fft_tests.jl index f48613b..266451e 100644 --- a/test/fft_tests.jl +++ b/test/fft_tests.jl @@ -174,6 +174,31 @@ end @test generic_fft(X,2) ≈ generic_fft(X, 2:2) ≈ generic_fft!(X̃,2) ≈ fft(X,2) @test X̃ ≈ fft(X,2) X̃ = copy(X) - @test generic_fft(X) ≈ generic_fft(X, 1:2) ≈ generic_fft!(X̃,1:2) ≈ fft(X) + @test generic_fft(X) ≈ generic_fft(X, 1:2) ≈ generic_fft!(X̃) ≈ fft(X) @test X̃ ≈ fft(X) + + X = randn(ComplexF64, 5, 6, 7) + for d in 1:3 + X̃ = copy(X) + @test generic_fft(X,d) ≈ generic_fft(X, d:d) ≈ generic_fft!(X̃, (d,)) ≈ fft(X,d) + @test X̃ ≈ fft(X,d) + end + X1 = copy(X) + X2 = copy(X) + @test generic_fft(X) ≈ generic_fft(X, 1:ndims(X)) ≈ generic_fft!(X1, 1:ndims(X1)) ≈ generic_fft!(X2) ≈ fft(X) + @test generic_fft(X, (1,3)) ≈ fft(X, (1,3)) + @test generic_fft(X, (2,3)) ≈ fft(X, (2,3)) + @test generic_fft(X, (1,2)) ≈ fft(X, (1,2)) + @test generic_fft(X, (2,1)) ≈ fft(X, (2,1)) + @test X1 ≈ fft(X) + @test X2 ≈ fft(X) + + N = 32 + A1 = randn(ComplexF64, N) + @allocations generic_fft!(A1) # compile + @test 0 == @allocations generic_fft!(A1) + + A2 = randn(ComplexF64, N, N, N) + @allocations generic_fft!(A2) # compile + @test N+150 > @allocations generic_fft!(A2) # a few allocations is OK end