Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/host/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ end
function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
res = similar(T, dims)
fill!(res, zero(U))
isempty(res) && return res
kernel = identity_kernel(get_backend(res))
kernel(res, size(res, 1), s.λ; ndrange=minimum(dims))
res
return res
end

(T::Type{<: AnyGPUArray})(s::UniformScaling{U}, dims::Dims{2}) where U = T{U}(s, dims)
Expand All @@ -48,9 +49,10 @@ end

function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
fill!(A, zero(T))
isempty(A) && return A
kernel = identity_kernel(get_backend(A))
kernel(A, size(A, 1), s.λ; ndrange=minimum(size(A)))
A
return A
end

function _one(unit::T, x::AbstractGPUMatrix) where {T}
Expand Down
34 changes: 18 additions & 16 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,27 +165,29 @@ for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriang
end

function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
@kernel function tril_kernel!(_A, _d)
I = @index(Global, Cartesian)
i, j = Tuple(I)
if i < j - _d
@inbounds _A[i, j] = zero(T)
isempty(A) && return A
@kernel function tril_kernel!(_A, _d)
I = @index(Global, Cartesian)
i, j = Tuple(I)
if i < j - _d
@inbounds _A[i, j] = zero(T)
end
end
end
tril_kernel!(get_backend(A))(A, d; ndrange = size(A))
return A
tril_kernel!(get_backend(A))(A, d; ndrange = size(A))
return A
end

function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
@kernel function triu_kernel!(_A, _d)
I = @index(Global, Cartesian)
i, j = Tuple(I)
if j < i + _d
@inbounds _A[i, j] = zero(T)
isempty(A) && return A
@kernel function triu_kernel!(_A, _d)
I = @index(Global, Cartesian)
i, j = Tuple(I)
if j < i + _d
@inbounds _A[i, j] = zero(T)
end
end
end
triu_kernel!(get_backend(A))(A, d; ndrange = size(A))
return A
triu_kernel!(get_backend(A))(A, d; ndrange = size(A))
return A
end

# check if upper triangular starting from the kth superdiagonal.
Expand Down
10 changes: 10 additions & 0 deletions test/testsuite/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@
@test Array(x1) ≈ x
end

@testset "empty" begin
x = Matrix{Float32}(I, (0, 3))
x1 = AT{Float32, 2}(I, (0, 3))

@test Array(x1) ≈ x

copyto!(x1, I)
@test Array(x1) ≈ x
end

@testset "JuliaGPU/GPUArrays.jl#439" begin
x = AT{Float32}(I, 500, 300)
y = Array{Float32}(I, 500, 300)
Expand Down
7 changes: 5 additions & 2 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,13 @@
@test_throws SingularException ldiv!(D, B)
end

@testset "$f! with diagonal $d" for (f, f!) in ((triu, triu!), (tril, tril!)),
@testset "$f with diagonal $d" for f in (triu, triu!, tril, tril!),
d in -2:2
A = randn(Float32, 10, 10)
@test f(A, d) == Array(f!(AT(A), d))
@test compare(f, AT, A, d)

A_empty = randn(Float32, 0, 0)
@test compare(f, AT, A_empty, d)
end
end

Expand Down
Loading