11function sparse_release_matrix_handle (A:: oneAbstractSparseMatrix )
2- queue = global_queue (context (A. nzVal), device (A. nzVal))
3- handle_ptr = Ref {matrix_handle_t} (A. handle)
4- onemklXsparse_release_matrix_handle (sycl_queue (queue), handle_ptr)
2+ if A. handle != = nothing
3+ try
4+ queue = global_queue (context (A. nzVal), device (A. nzVal))
5+ handle_ptr = Ref {matrix_handle_t} (A. handle)
6+ onemklXsparse_release_matrix_handle (sycl_queue (queue), handle_ptr)
7+ # Only synchronize after successful release to ensure completion
8+ synchronize (queue)
9+ catch err
10+ # Don't let finalizer errors crash the program
11+ @warn " Error releasing sparse matrix handle" exception= err
12+ end
13+ end
514end
615
716for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32 ),
@@ -13,46 +22,72 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
1322 (:onemklZsparse_set_csr_data , :ComplexF64 , :Int32 ),
1423 (:onemklZsparse_set_csr_data_64 , :ComplexF64 , :Int64 ))
1524 @eval begin
16- function oneSparseMatrixCSR (A:: SparseMatrixCSC{$elty, $intty} )
25+
26+ function oneSparseMatrixCSR (
27+ rowPtr:: oneVector{$intty} , colVal:: oneVector{$intty} ,
28+ nzVal:: oneVector{$elty} , dims:: NTuple{2, Int}
29+ )
1730 handle_ptr = Ref {matrix_handle_t} ()
1831 onemklXsparse_init_matrix_handle (handle_ptr)
32+ m, n = dims
33+ nnzA = length (nzVal)
34+ queue = global_queue (context (nzVal), device (nzVal))
35+ # Don't update handle if matrix is empty
36+ if m != 0 && n != 0
37+ $ fname (sycl_queue (queue), handle_ptr[], m, n, ' O' , rowPtr, colVal, nzVal)
38+ dA = oneSparseMatrixCSR {$elty, $intty} (handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
39+ finalizer (sparse_release_matrix_handle, dA)
40+ else
41+ dA = oneSparseMatrixCSR {$elty, $intty} (nothing , rowPtr, colVal, nzVal, (m, n), nnzA)
42+ end
43+ return dA
44+ end
45+
46+ function oneSparseMatrixCSC (
47+ colPtr:: oneVector{$intty} , rowVal:: oneVector{$intty} ,
48+ nzVal:: oneVector{$elty} , dims:: NTuple{2, Int}
49+ )
50+ queue = global_queue (context (nzVal), device (nzVal))
51+ handle_ptr = Ref {matrix_handle_t} ()
52+ onemklXsparse_init_matrix_handle (handle_ptr)
53+ m, n = dims
54+ nnzA = length (nzVal)
55+ # Don't update handle if matrix is empty
56+ if m != 0 && n != 0
57+ $ fname (sycl_queue (queue), handle_ptr[], n, m, ' O' , colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
58+ dA = oneSparseMatrixCSC {$elty, $intty} (handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
59+ finalizer (sparse_release_matrix_handle, dA)
60+ else
61+ dA = oneSparseMatrixCSC {$elty, $intty} (nothing , colPtr, rowVal, nzVal, (m, n), nnzA)
62+ end
63+ return dA
64+ end
65+
66+
67+ function oneSparseMatrixCSR (A:: SparseMatrixCSC{$elty, $intty} )
1968 m, n = size (A)
2069 At = SparseMatrixCSC (A |> transpose)
2170 rowPtr = oneVector {$intty} (At. colptr)
2271 colVal = oneVector {$intty} (At. rowval)
2372 nzVal = oneVector {$elty} (At. nzval)
24- nnzA = length (At. nzval)
25- queue = global_queue (context (nzVal), device ())
26- $ fname (sycl_queue (queue), handle_ptr[], m, n, ' O' , rowPtr, colVal, nzVal)
27- dA = oneSparseMatrixCSR {$elty, $intty} (handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA)
28- finalizer (sparse_release_matrix_handle, dA)
29- return dA
73+ return oneSparseMatrixCSR (rowPtr, colVal, nzVal, (m, n))
3074 end
3175
3276 function SparseMatrixCSC (A:: oneSparseMatrixCSR{$elty, $intty} )
33- handle_ptr = Ref {matrix_handle_t} ()
3477 At = SparseMatrixCSC (reverse (A. dims)... , Vector (A. rowPtr), Vector (A. colVal), Vector (A. nzVal))
3578 A_csc = SparseMatrixCSC (At |> transpose)
3679 return A_csc
3780 end
3881
3982 function oneSparseMatrixCSC (A:: SparseMatrixCSC{$elty, $intty} )
40- handle_ptr = Ref {matrix_handle_t} ()
41- onemklXsparse_init_matrix_handle (handle_ptr)
4283 m, n = size (A)
4384 colPtr = oneVector {$intty} (A. colptr)
4485 rowVal = oneVector {$intty} (A. rowval)
4586 nzVal = oneVector {$elty} (A. nzval)
46- nnzA = length (A. nzval)
47- queue = global_queue (context (nzVal), device ())
48- $ fname (sycl_queue (queue), handle_ptr[], n, m, ' O' , colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
49- dA = oneSparseMatrixCSC {$elty, $intty} (handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
50- finalizer (sparse_release_matrix_handle, dA)
51- return dA
87+ return oneSparseMatrixCSC (colPtr, rowVal, nzVal, (m, n))
5288 end
5389
5490 function SparseMatrixCSC (A:: oneSparseMatrixCSC{$elty, $intty} )
55- handle_ptr = Ref {matrix_handle_t} ()
5691 A_csc = SparseMatrixCSC (A. dims... , Vector (A. colPtr), Vector (A. rowVal), Vector (A. nzVal))
5792 return A_csc
5893 end
@@ -77,15 +112,18 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
77112 colInd = oneVector {$intty} (col)
78113 nzVal = oneVector {$elty} (val)
79114 nnzA = length (val)
80- queue = global_queue (context (nzVal), device ())
81- $ fname (sycl_queue (queue), handle_ptr[], m, n, nnzA, ' O' , rowInd, colInd, nzVal)
82- dA = oneSparseMatrixCOO {$elty, $intty} (handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
83- finalizer (sparse_release_matrix_handle, dA)
115+ queue = global_queue (context (nzVal), device (nzVal))
116+ if m != 0 && n != 0
117+ $ fname (sycl_queue (queue), handle_ptr[], m, n, nnzA, ' O' , rowInd, colInd, nzVal)
118+ dA = oneSparseMatrixCOO {$elty, $intty} (handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
119+ finalizer (sparse_release_matrix_handle, dA)
120+ else
121+ dA = oneSparseMatrixCOO {$elty, $intty} (nothing , rowInd, colInd, nzVal, (m,n), nnzA)
122+ end
84123 return dA
85124 end
86125
87126 function SparseMatrixCSC (A:: oneSparseMatrixCOO{$elty, $intty} )
88- handle_ptr = Ref {matrix_handle_t} ()
89127 A = sparse (Vector (A. rowInd), Vector (A. colInd), Vector (A. nzVal), A. dims... )
90128 return A
91129 end
@@ -105,7 +143,7 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
105143 beta:: Number ,
106144 y:: oneStridedVector{$elty} )
107145
108- queue = global_queue (context (x), device ())
146+ queue = global_queue (context (x), device (x ))
109147 $ fname (sycl_queue (queue), trans, alpha, A. handle, x, beta, y)
110148 y
111149 end
@@ -140,8 +178,11 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
140178 beta:: Number ,
141179 y:: oneStridedVector{$elty} )
142180
143- queue = global_queue (context (x), device ())
144- $ fname (sycl_queue (queue), flip_trans (trans), alpha, A. handle, x, beta, y)
181+ queue = global_queue (context (x), device (x))
182+ m, n = size (A)
183+ if m != 0 && n != 0
184+ $ fname (sycl_queue (queue), flip_trans (trans), alpha, A. handle, x, beta, y)
185+ end
145186 y
146187 end
147188 end
@@ -173,7 +214,7 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
173214 beta = conj (beta)
174215 end
175216
176- queue = global_queue (context (x), device ())
217+ queue = global_queue (context (x), device (x ))
177218 $ fname (sycl_queue (queue), flip_trans (trans), alpha, A. handle, x, beta, y)
178219
179220 if trans == ' C'
@@ -217,7 +258,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
217258 nrhs = size (B, 2 )
218259 ldb = max (1 ,stride (B,2 ))
219260 ldc = max (1 ,stride (C,2 ))
220- queue = global_queue (context (C), device ())
261+ queue = global_queue (context (C), device (C ))
221262 $ fname (sycl_queue (queue), ' C' , transa, transb, alpha, A. handle, B, nrhs, ldb, beta, C, ldc)
222263 C
223264 end
@@ -254,7 +295,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
254295 nrhs = size (B, 2 )
255296 ldb = max (1 ,stride (B,2 ))
256297 ldc = max (1 ,stride (C,2 ))
257- queue = global_queue (context (C), device ())
298+ queue = global_queue (context (C), device (C ))
258299 $ fname (sycl_queue (queue), ' C' , flip_trans (transa), transb, alpha, A. handle, B, nrhs, ldb, beta, C, ldc)
259300 C
260301 end
@@ -289,7 +330,7 @@ for (fname, elty) in (
289330 nrhs = size (B, 2 )
290331 ldb = max (1 , stride (B, 2 ))
291332 ldc = max (1 , stride (C, 2 ))
292- queue = global_queue (context (C), device ())
333+ queue = global_queue (context (C), device (C ))
293334
294335 # Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C)
295336 # Prepare conj(C) in-place and conj(B) into a temporary if needed
@@ -359,7 +400,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
359400 beta:: Number ,
360401 y:: oneStridedVector{$elty} )
361402
362- queue = global_queue (context (y), device ())
403+ queue = global_queue (context (y), device (y ))
363404 $ fname (sycl_queue (queue), uplo, alpha, A. handle, x, beta, y)
364405 y
365406 end
@@ -379,7 +420,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
379420 beta:: Number ,
380421 y:: oneStridedVector{$elty} )
381422
382- queue = global_queue (context (y), device ())
423+ queue = global_queue (context (y), device (y ))
383424 $ fname (sycl_queue (queue), flip_uplo (uplo), alpha, A. handle, x, beta, y)
384425 y
385426 end
@@ -400,7 +441,7 @@ for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
400441 beta:: Number ,
401442 y:: oneStridedVector{$elty} )
402443
403- queue = global_queue (context (y), device ())
444+ queue = global_queue (context (y), device (y ))
404445 $ fname (sycl_queue (queue), uplo, trans, diag, alpha, A. handle, x, beta, y)
405446 y
406447 end
@@ -442,7 +483,7 @@ for (fname, elty) in (
442483 " Convert to oneSparseMatrixCSR format instead."
443484 )
444485 )
445- queue = global_queue (context (y), device ())
486+ queue = global_queue (context (y), device (y ))
446487 $ fname (sycl_queue (queue), uplo, flip_trans (trans), diag, alpha, A. handle, x, beta, y)
447488 return y
448489 end
@@ -475,7 +516,7 @@ for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
475516 x:: oneStridedVector{$elty} ,
476517 y:: oneStridedVector{$elty} )
477518
478- queue = global_queue (context (y), device ())
519+ queue = global_queue (context (y), device (y ))
479520 $ fname (sycl_queue (queue), uplo, trans, diag, alpha, A. handle, x, y)
480521 y
481522 end
@@ -512,7 +553,7 @@ for (fname, elty) in (
512553 " Convert to oneSparseMatrixCSR format instead."
513554 )
514555 )
515- queue = global_queue (context (y), device ())
556+ queue = global_queue (context (y), device (y ))
516557 onemklXsparse_optimize_trsv (sycl_queue (queue), uplo, flip_trans (trans), diag, A. handle)
517558 return A
518559 end
@@ -555,7 +596,7 @@ for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
555596 nrhs = size (X, 2 )
556597 ldx = max (1 ,stride (X,2 ))
557598 ldy = max (1 ,stride (Y,2 ))
558- queue = global_queue (context (Y), device ())
599+ queue = global_queue (context (Y), device (Y ))
559600 $ fname (sycl_queue (queue), ' C' , transA, transX, uplo, diag, alpha, A. handle, X, nrhs, ldx, Y, ldy)
560601 Y
561602 end
@@ -614,7 +655,7 @@ for (fname, elty) in (
614655 nrhs = size (X, 2 )
615656 ldx = max (1 , stride (X, 2 ))
616657 ldy = max (1 , stride (Y, 2 ))
617- queue = global_queue (context (Y), device ())
658+ queue = global_queue (context (Y), device (Y ))
618659 $ fname (sycl_queue (queue), ' C' , flip_trans (transA), transX, uplo, diag, alpha, A. handle, X, nrhs, ldx, Y, ldy)
619660 return Y
620661 end
0 commit comments