Skip to content

Commit 80d24c2

Browse files
committed
Towards MadNLP
1 parent 5c833a7 commit 80d24c2

File tree

10 files changed

+174
-89
lines changed

10 files changed

+174
-89
lines changed

lib/mkl/array.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
55
const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2}
66

77
mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
8-
handle::matrix_handle_t
8+
handle::Union{Nothing, matrix_handle_t}
99
rowPtr::oneVector{Ti}
1010
colVal::oneVector{Ti}
1111
nzVal::oneVector{Tv}
@@ -14,7 +14,7 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1414
end
1515

1616
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
17-
handle::matrix_handle_t
17+
handle::Union{Nothing, matrix_handle_t}
1818
colPtr::oneVector{Ti}
1919
rowVal::oneVector{Ti}
2020
nzVal::oneVector{Tv}
@@ -23,7 +23,7 @@ mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
2323
end
2424

2525
mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
26-
handle::matrix_handle_t
26+
handle::Union{Nothing, matrix_handle_t}
2727
rowInd::oneVector{Ti}
2828
colInd::oneVector{Ti}
2929
nzVal::oneVector{Tv}

lib/mkl/wrappers_blas.jl

Lines changed: 42 additions & 42 deletions
Large diffs are not rendered by default.

lib/mkl/wrappers_sparse.jl

Lines changed: 81 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
function sparse_release_matrix_handle(A::oneAbstractSparseMatrix)
2-
queue = global_queue(context(A.nzVal), device(A.nzVal))
3-
handle_ptr = Ref{matrix_handle_t}(A.handle)
4-
onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
2+
if A.handle !== nothing
3+
try
4+
queue = global_queue(context(A.nzVal), device(A.nzVal))
5+
handle_ptr = Ref{matrix_handle_t}(A.handle)
6+
onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
7+
# Only synchronize after successful release to ensure completion
8+
synchronize(queue)
9+
catch err
10+
# Don't let finalizer errors crash the program
11+
@warn "Error releasing sparse matrix handle" exception=err
12+
end
13+
end
514
end
615

716
for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32),
@@ -13,46 +22,72 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
1322
(:onemklZsparse_set_csr_data , :ComplexF64, :Int32),
1423
(:onemklZsparse_set_csr_data_64, :ComplexF64, :Int64))
1524
@eval begin
16-
function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty})
25+
26+
function oneSparseMatrixCSR(
27+
rowPtr::oneVector{$intty}, colVal::oneVector{$intty},
28+
nzVal::oneVector{$elty}, dims::NTuple{2, Int}
29+
)
1730
handle_ptr = Ref{matrix_handle_t}()
1831
onemklXsparse_init_matrix_handle(handle_ptr)
32+
m, n = dims
33+
nnzA = length(nzVal)
34+
queue = global_queue(context(nzVal), device(nzVal))
35+
# Don't update handle if matrix is empty
36+
if m != 0 && n != 0
37+
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
38+
dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
39+
finalizer(sparse_release_matrix_handle, dA)
40+
else
41+
dA = oneSparseMatrixCSR{$elty, $intty}(nothing, rowPtr, colVal, nzVal, (m, n), nnzA)
42+
end
43+
return dA
44+
end
45+
46+
function oneSparseMatrixCSC(
47+
colPtr::oneVector{$intty}, rowVal::oneVector{$intty},
48+
nzVal::oneVector{$elty}, dims::NTuple{2, Int}
49+
)
50+
queue = global_queue(context(nzVal), device(nzVal))
51+
handle_ptr = Ref{matrix_handle_t}()
52+
onemklXsparse_init_matrix_handle(handle_ptr)
53+
m, n = dims
54+
nnzA = length(nzVal)
55+
# Don't update handle if matrix is empty
56+
if m != 0 && n != 0
57+
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
58+
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
59+
finalizer(sparse_release_matrix_handle, dA)
60+
else
61+
dA = oneSparseMatrixCSC{$elty, $intty}(nothing, colPtr, rowVal, nzVal, (m, n), nnzA)
62+
end
63+
return dA
64+
end
65+
66+
67+
function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty})
1968
m, n = size(A)
2069
At = SparseMatrixCSC(A |> transpose)
2170
rowPtr = oneVector{$intty}(At.colptr)
2271
colVal = oneVector{$intty}(At.rowval)
2372
nzVal = oneVector{$elty}(At.nzval)
24-
nnzA = length(At.nzval)
25-
queue = global_queue(context(nzVal), device())
26-
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
27-
dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA)
28-
finalizer(sparse_release_matrix_handle, dA)
29-
return dA
73+
return oneSparseMatrixCSR(rowPtr, colVal, nzVal, (m, n))
3074
end
3175

3276
function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
33-
handle_ptr = Ref{matrix_handle_t}()
3477
At = SparseMatrixCSC(reverse(A.dims)..., Vector(A.rowPtr), Vector(A.colVal), Vector(A.nzVal))
3578
A_csc = SparseMatrixCSC(At |> transpose)
3679
return A_csc
3780
end
3881

3982
function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty})
40-
handle_ptr = Ref{matrix_handle_t}()
41-
onemklXsparse_init_matrix_handle(handle_ptr)
4283
m, n = size(A)
4384
colPtr = oneVector{$intty}(A.colptr)
4485
rowVal = oneVector{$intty}(A.rowval)
4586
nzVal = oneVector{$elty}(A.nzval)
46-
nnzA = length(A.nzval)
47-
queue = global_queue(context(nzVal), device())
48-
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
49-
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
50-
finalizer(sparse_release_matrix_handle, dA)
51-
return dA
87+
return oneSparseMatrixCSC(colPtr, rowVal, nzVal, (m, n))
5288
end
5389

5490
function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
55-
handle_ptr = Ref{matrix_handle_t}()
5691
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
5792
return A_csc
5893
end
@@ -77,15 +112,18 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
77112
colInd = oneVector{$intty}(col)
78113
nzVal = oneVector{$elty}(val)
79114
nnzA = length(val)
80-
queue = global_queue(context(nzVal), device())
81-
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
82-
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
83-
finalizer(sparse_release_matrix_handle, dA)
115+
queue = global_queue(context(nzVal), device(nzVal))
116+
if m != 0 && n != 0
117+
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
118+
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
119+
finalizer(sparse_release_matrix_handle, dA)
120+
else
121+
dA = oneSparseMatrixCOO{$elty, $intty}(nothing, rowInd, colInd, nzVal, (m,n), nnzA)
122+
end
84123
return dA
85124
end
86125

87126
function SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty})
88-
handle_ptr = Ref{matrix_handle_t}()
89127
A = sparse(Vector(A.rowInd), Vector(A.colInd), Vector(A.nzVal), A.dims...)
90128
return A
91129
end
@@ -105,7 +143,7 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
105143
beta::Number,
106144
y::oneStridedVector{$elty})
107145

108-
queue = global_queue(context(x), device())
146+
queue = global_queue(context(x), device(x))
109147
$fname(sycl_queue(queue), trans, alpha, A.handle, x, beta, y)
110148
y
111149
end
@@ -140,8 +178,11 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
140178
beta::Number,
141179
y::oneStridedVector{$elty})
142180

143-
queue = global_queue(context(x), device())
144-
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
181+
queue = global_queue(context(x), device(x))
182+
m, n = size(A)
183+
if m != 0 && n != 0
184+
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
185+
end
145186
y
146187
end
147188
end
@@ -173,7 +214,7 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
173214
beta = conj(beta)
174215
end
175216

176-
queue = global_queue(context(x), device())
217+
queue = global_queue(context(x), device(x))
177218
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
178219

179220
if trans == 'C'
@@ -217,7 +258,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
217258
nrhs = size(B, 2)
218259
ldb = max(1,stride(B,2))
219260
ldc = max(1,stride(C,2))
220-
queue = global_queue(context(C), device())
261+
queue = global_queue(context(C), device(C))
221262
$fname(sycl_queue(queue), 'C', transa, transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
222263
C
223264
end
@@ -254,7 +295,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
254295
nrhs = size(B, 2)
255296
ldb = max(1,stride(B,2))
256297
ldc = max(1,stride(C,2))
257-
queue = global_queue(context(C), device())
298+
queue = global_queue(context(C), device(C))
258299
$fname(sycl_queue(queue), 'C', flip_trans(transa), transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
259300
C
260301
end
@@ -289,7 +330,7 @@ for (fname, elty) in (
289330
nrhs = size(B, 2)
290331
ldb = max(1, stride(B, 2))
291332
ldc = max(1, stride(C, 2))
292-
queue = global_queue(context(C), device())
333+
queue = global_queue(context(C), device(C))
293334

294335
# Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C)
295336
# Prepare conj(C) in-place and conj(B) into a temporary if needed
@@ -359,7 +400,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
359400
beta::Number,
360401
y::oneStridedVector{$elty})
361402

362-
queue = global_queue(context(y), device())
403+
queue = global_queue(context(y), device(y))
363404
$fname(sycl_queue(queue), uplo, alpha, A.handle, x, beta, y)
364405
y
365406
end
@@ -379,7 +420,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
379420
beta::Number,
380421
y::oneStridedVector{$elty})
381422

382-
queue = global_queue(context(y), device())
423+
queue = global_queue(context(y), device(y))
383424
$fname(sycl_queue(queue), flip_uplo(uplo), alpha, A.handle, x, beta, y)
384425
y
385426
end
@@ -400,7 +441,7 @@ for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
400441
beta::Number,
401442
y::oneStridedVector{$elty})
402443

403-
queue = global_queue(context(y), device())
444+
queue = global_queue(context(y), device(y))
404445
$fname(sycl_queue(queue), uplo, trans, diag, alpha, A.handle, x, beta, y)
405446
y
406447
end
@@ -442,7 +483,7 @@ for (fname, elty) in (
442483
"Convert to oneSparseMatrixCSR format instead."
443484
)
444485
)
445-
queue = global_queue(context(y), device())
486+
queue = global_queue(context(y), device(y))
446487
$fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, beta, y)
447488
return y
448489
end
@@ -475,7 +516,7 @@ for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
475516
x::oneStridedVector{$elty},
476517
y::oneStridedVector{$elty})
477518

478-
queue = global_queue(context(y), device())
519+
queue = global_queue(context(y), device(y))
479520
$fname(sycl_queue(queue), uplo, trans, diag, alpha, A.handle, x, y)
480521
y
481522
end
@@ -512,7 +553,7 @@ for (fname, elty) in (
512553
"Convert to oneSparseMatrixCSR format instead."
513554
)
514555
)
515-
queue = global_queue(context(y), device())
556+
queue = global_queue(context(y), device(y))
516557
onemklXsparse_optimize_trsv(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
517558
return A
518559
end
@@ -555,7 +596,7 @@ for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
555596
nrhs = size(X, 2)
556597
ldx = max(1,stride(X,2))
557598
ldy = max(1,stride(Y,2))
558-
queue = global_queue(context(Y), device())
599+
queue = global_queue(context(Y), device(Y))
559600
$fname(sycl_queue(queue), 'C', transA, transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
560601
Y
561602
end
@@ -614,7 +655,7 @@ for (fname, elty) in (
614655
nrhs = size(X, 2)
615656
ldx = max(1, stride(X, 2))
616657
ldy = max(1, stride(Y, 2))
617-
queue = global_queue(context(Y), device())
658+
queue = global_queue(context(Y), device(Y))
618659
$fname(sycl_queue(queue), 'C', flip_trans(transA), transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
619660
return Y
620661
end

src/indexing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ function Base.findall(bools::oneArray{Bool})
2020
I = keytype(bools)
2121

2222
indices = cumsum(reshape(bools, prod(size(bools))))
23-
oneL0.synchronize()
2423

2524
n = isempty(indices) ? 0 : @allowscalar indices[end]
2625

2726
ys = oneArray{I}(undef, n)
2827

2928
if n > 0
30-
@oneapi items = length(bools) _ker!(ys, bools, indices)
29+
kernel = @oneapi launch=false _ker!(ys, bools, indices)
30+
group_size = launch_configuration(kernel)
31+
kernel(ys, bools, indices; items=group_size, groups=cld(length(bools), group_size))
3132
end
32-
oneL0.synchronize()
33-
unsafe_free!(indices)
33+
# unsafe_free!(indices)
3434

3535
return ys
3636
end

src/oneAPI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ include("utils.jl")
6969
include("oneAPIKernels.jl")
7070
import .oneAPIKernels: oneAPIBackend
7171
include("accumulate.jl")
72+
include("sorting.jl")
7273
include("indexing.jl")
7374
export oneAPIBackend
7475

src/sorting.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Base.sort!(x::oneArray; kwargs...) = (AK.sort!(x; kwargs...); return x)
2+
Base.sortperm!(ix::oneArray, x::oneArray; kwargs...) = (AK.sortperm!(ix, x; kwargs...); return ix)
3+
Base.sortperm(x::oneArray; kwargs...) = sortperm!(oneArray(1:length(x)), x; kwargs...)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1919
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2020
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2121
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"
22+
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
2223
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"

test/indexing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,18 @@ using oneAPI
1717
data = oneArray(collect(1:6))
1818
mask = oneArray(Bool[true, false, true, false, false, true])
1919
@test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])]
20+
21+
# Test with array larger than 1024 to trigger multiple groups
22+
large_size = 2048
23+
large_mask = oneArray(rand(Bool, large_size))
24+
large_result_gpu = Array(findall(large_mask))
25+
large_result_cpu = findall(Array(large_mask))
26+
@test large_result_gpu == large_result_cpu
27+
28+
# Test with even larger array to ensure robustness
29+
very_large_size = 5000
30+
very_large_mask = oneArray(fill(true, very_large_size)) # all true for predictable result
31+
very_large_result_gpu = Array(findall(very_large_mask))
32+
very_large_result_cpu = findall(fill(true, very_large_size))
33+
@test very_large_result_gpu == very_large_result_cpu
2034
end

test/onemkl.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,10 @@ end
10901090
B = oneSparseMatrixCSR(A)
10911091
A2 = SparseMatrixCSC(B)
10921092
@test A == A2
1093+
C = oneSparseMatrixCSR(B.rowPtr, B.colVal, B.nzVal, size(B))
1094+
A3 = SparseMatrixCSC(C)
1095+
@test A == A3
1096+
D = oneSparseMatrixCSR(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
10931097
end
10941098
end
10951099

@@ -1101,6 +1105,10 @@ end
11011105
B = oneSparseMatrixCSC(A)
11021106
A2 = SparseMatrixCSC(B)
11031107
@test A == A2
1108+
C = oneSparseMatrixCSC(A.colptr |> oneVector, A.rowval |> oneVector, A.nzval |> oneVector, size(A))
1109+
A3 = SparseMatrixCSC(C)
1110+
@test A == A3
1111+
D = oneSparseMatrixCSC(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
11041112
end
11051113
end
11061114

test/sorting.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using Test
2+
using oneAPI
3+
4+
@testset "sorting" begin
5+
data = oneArray([3, 1, 4, 1, 5])
6+
sort!(data)
7+
@test Array(data) == [1, 1, 3, 4, 5]
8+
9+
data_rev = oneArray([3, 1, 4, 1, 5])
10+
sort!(data_rev, rev = true)
11+
@test Array(data_rev) == [5, 4, 3, 1, 1]
12+
data = oneArray([3, 1, 4, 1, 5])
13+
@test Array(sortperm(data)) == sortperm([3, 1, 4, 1, 5])
14+
15+
data_rev = oneArray([3, 1, 4, 1, 5])
16+
@test Array(sortperm(data_rev, rev = true)) == sortperm([3, 1, 4, 1, 5], rev = true)
17+
end

0 commit comments

Comments
 (0)