1+ //! Implement linear solver using LU decomposition
2+ //! for tridiagonal matrix
3+
4+ use lapacke;
5+ use ndarray:: * ;
6+ use num_traits:: Zero ;
7+
8+ use super :: NormType ;
9+ use super :: { into_result, Pivot , Transpose } ;
10+
11+ use crate :: error:: * ;
12+ use crate :: layout:: MatrixLayout ;
13+ use crate :: tridiagonal:: { TriDiagonal , LUFactorizedTriDiagonal } ;
14+ use crate :: types:: * ;
15+
16+ /// Wraps `*gttrf`, `*gtcon` and `*gttrs`
17+ pub trait TriDiagonal_ : Scalar + Sized {
18+ /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
19+ /// partial pivoting with row interchanges.
20+ unsafe fn lu_tridiagonal ( a : & mut TriDiagonal < Self > ) -> Result < ( Array1 < Self > , Pivot ) > ;
21+ /// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm.
22+ unsafe fn rcond_tridiagonal ( lu : & LUFactorizedTriDiagonal < Self > ) -> Result < Self :: Real > ;
23+ unsafe fn solve_tridiagonal (
24+ lu : & LUFactorizedTriDiagonal < Self > ,
25+ bl : MatrixLayout ,
26+ t : Transpose ,
27+ b : & mut [ Self ] ) -> Result < ( ) > ;
28+ }
29+
30+ macro_rules! impl_tridiagonal {
31+ ( $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
32+ impl TriDiagonal_ for $scalar {
33+ unsafe fn lu_tridiagonal( a: & mut TriDiagonal <Self >) -> Result <( Array1 <Self >, Pivot ) > {
34+ let ( n, _) = a. l. size( ) ;
35+ let dl = a. dl. as_slice_mut( ) . unwrap( ) ;
36+ let d = a. d. as_slice_mut( ) . unwrap( ) ;
37+ let du = a. du. as_slice_mut( ) . unwrap( ) ;
38+ let mut du2 = vec![ Zero :: zero( ) ; ( n-2 ) as usize ] ;
39+ let mut ipiv = vec![ 0 ; n as usize ] ;
40+ let info = $gttrf( n, dl, d, du, & mut du2, & mut ipiv) ;
41+ into_result( info, ( arr1( & du2) , ipiv) )
42+ }
43+
44+ unsafe fn rcond_tridiagonal( lu: & LUFactorizedTriDiagonal <Self >) -> Result <Self :: Real > {
45+ let ( n, _) = lu. a. l. size( ) ;
46+ let dl = lu. a. dl. as_slice( ) . unwrap( ) ;
47+ let d = lu. a. d. as_slice( ) . unwrap( ) ;
48+ let du = lu. a. du. as_slice( ) . unwrap( ) ;
49+ let du2 = lu. du2. as_slice( ) . unwrap( ) ;
50+ let ipiv = & lu. ipiv;
51+ let anorm = lu. a. n1;
52+ let mut rcond = Self :: Real :: zero( ) ;
53+ let info = $gtcon(
54+ NormType :: One as u8 ,
55+ n,
56+ dl,
57+ d,
58+ du,
59+ du2,
60+ ipiv,
61+ anorm,
62+ & mut rcond,
63+ ) ;
64+ into_result( info, rcond)
65+ }
66+
67+ unsafe fn solve_tridiagonal(
68+ lu: & LUFactorizedTriDiagonal <Self >,
69+ bl: MatrixLayout ,
70+ t: Transpose ,
71+ b: & mut [ Self ]
72+ ) -> Result <( ) > {
73+ let ( n, _) = lu. a. l. size( ) ;
74+ let ( _, nrhs) = bl. size( ) ;
75+ let dl = lu. a. dl. as_slice( ) . unwrap( ) ;
76+ let d = lu. a. d. as_slice( ) . unwrap( ) ;
77+ let du = lu. a. du. as_slice( ) . unwrap( ) ;
78+ let du2 = lu. du2. as_slice( ) . unwrap( ) ;
79+ let ipiv = & lu. ipiv;
80+ let ldb = bl. lda( ) ;
81+ let info = $gttrs(
82+ lu. a. l. lapacke_layout( ) ,
83+ t as u8 ,
84+ n,
85+ nrhs,
86+ dl,
87+ d,
88+ du,
89+ du2,
90+ ipiv,
91+ b,
92+ ldb,
93+ ) ;
94+ into_result( info, ( ) )
95+ }
96+ }
97+ } ;
98+ } // impl_tridiagonal!
99+
100+ impl_tridiagonal ! ( f64 , lapacke:: dgttrf, lapacke:: dgtcon, lapacke:: dgttrs) ;
101+ impl_tridiagonal ! ( f32 , lapacke:: sgttrf, lapacke:: sgtcon, lapacke:: sgttrs) ;
102+ impl_tridiagonal ! ( c64, lapacke:: zgttrf, lapacke:: zgtcon, lapacke:: zgttrs) ;
103+ impl_tridiagonal ! ( c32, lapacke:: cgttrf, lapacke:: cgtcon, lapacke:: cgttrs) ;
0 commit comments