22//! &
33//! Methods for tridiagonal matrices
44
5+ use std:: ops:: { Index , IndexMut } ;
6+
57use cauchy:: Scalar ;
68use ndarray:: * ;
79use num_traits:: One ;
810
9- use crate :: opnorm:: OperationNorm ;
10-
1111use super :: convert:: * ;
1212use super :: error:: * ;
1313use super :: lapack:: * ;
1414use super :: layout:: * ;
1515
1616/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
17- /// This struct also holds the layout and 1-norm of the raw matrix
18- /// for some methods (eg. rcond_tridiagonal()).
19- #[ derive( Clone ) ]
17+ /// This struct also holds the layout of the raw matrix.
18+ #[ derive( Clone , PartialEq ) ]
2019pub struct TriDiagonal < A : Scalar > {
2120 /// layout of raw matrix
2221 pub l : MatrixLayout ,
23- /// the one norm of raw matrix
24- pub n1 : <A as Scalar >:: Real ,
2522 /// (n-1) sub-diagonal elements of matrix.
2623 pub dl : Array1 < A > ,
2724 /// (n) diagonal elements of matrix.
@@ -30,10 +27,73 @@ pub struct TriDiagonal<A: Scalar> {
3027 pub du : Array1 < A > ,
3128}
3229
30+ pub trait TridiagIndex {
31+ fn to_tuple ( & self ) -> ( i32 , i32 ) ;
32+ }
33+ impl TridiagIndex for [ Ix ; 2 ] {
34+ fn to_tuple ( & self ) -> ( i32 , i32 ) {
35+ ( self [ 0 ] as i32 , self [ 1 ] as i32 )
36+ }
37+ }
38+
39+ fn debug_bounds_check_tridiag ( n : i32 , row : i32 , col : i32 ) {
40+ if std:: cmp:: max ( row, col) >= n {
41+ panic ! (
42+ "ndarray: index {:?} is out of bounds for array of shape {}" ,
43+ [ row, col] ,
44+ n
45+ ) ;
46+ }
47+ }
48+
49+ impl < A , I > Index < I > for TriDiagonal < A >
50+ where
51+ A : Scalar ,
52+ I : TridiagIndex ,
53+ {
54+ type Output = A ;
55+ #[ inline]
56+ fn index ( & self , index : I ) -> & A {
57+ let ( n, _) = self . l . size ( ) ;
58+ let ( row, col) = index. to_tuple ( ) ;
59+ debug_bounds_check_tridiag ( n, row, col) ;
60+ match row - col {
61+ 0 => & self . d [ row as usize ] ,
62+ 1 => & self . dl [ col as usize ] ,
63+ -1 => & self . du [ row as usize ] ,
64+ _ => panic ! (
65+ "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element" ,
66+ [ row, col]
67+ ) ,
68+ }
69+ }
70+ }
71+
72+ impl < A , I > IndexMut < I > for TriDiagonal < A >
73+ where
74+ A : Scalar ,
75+ I : TridiagIndex ,
76+ {
77+ #[ inline]
78+ fn index_mut ( & mut self , index : I ) -> & mut A {
79+ let ( n, _) = self . l . size ( ) ;
80+ let ( row, col) = index. to_tuple ( ) ;
81+ debug_bounds_check_tridiag ( n, row, col) ;
82+ match row - col {
83+ 0 => & mut self . d [ row as usize ] ,
84+ 1 => & mut self . dl [ col as usize ] ,
85+ -1 => & mut self . du [ row as usize ] ,
86+ _ => panic ! (
87+ "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element" ,
88+ [ row, col]
89+ ) ,
90+ }
91+ }
92+ }
93+
3394/// An interface for making a TriDiagonal struct.
3495pub trait ToTriDiagonal < A : Scalar > {
3596 /// Extract tridiagonal elements and layout of the raw matrix.
36- /// And also calculate 1-norm.
3797 ///
3898 /// If the raw matrix has some non-tridiagonal elements,
3999 /// they will be ignored.
@@ -53,12 +113,11 @@ where
53113 if n < 2 {
54114 panic ! ( "Cannot make a tridiagonal matrix of shape=(1, 1)!" ) ;
55115 }
56- let n1 = self . opnorm_one ( ) ?;
57116
58117 let dl = self . slice ( s ! [ 1 ..n, 0 ..n - 1 ] ) . diag ( ) . to_owned ( ) ;
59118 let d = self . diag ( ) . to_owned ( ) ;
60119 let du = self . slice ( s ! [ 0 ..n - 1 , 1 ..n] ) . diag ( ) . to_owned ( ) ;
61- Ok ( TriDiagonal { l, n1 , dl, d, du } )
120+ Ok ( TriDiagonal { l, dl, d, du } )
62121 }
63122}
64123
@@ -130,13 +189,14 @@ pub trait SolveTriDiagonalInplace<A: Scalar, D: Dimension> {
130189pub struct LUFactorizedTriDiagonal < A : Scalar > {
131190 /// A tridiagonal matrix which consists of
132191 /// - l : layout of raw matrix
133- /// - n1: the one norm of raw matrix
134192 /// - dl: (n-1) multipliers that define the matrix L.
135193 /// - d : (n) diagonal elements of the upper triangular matrix U.
136194 /// - du: (n-1) elements of the first super-diagonal of U.
137195 pub a : TriDiagonal < A > ,
138196 /// (n-2) elements of the second super-diagonal of U.
139197 pub du2 : Array1 < A > ,
198+ /// 1-norm of raw matrix (used in .rcond_tridiagonal()).
199+ pub anom : A :: Real ,
140200 /// The pivot indices that define the permutation matrix `P`.
141201 pub ipiv : Pivot ,
142202}
@@ -598,10 +658,11 @@ where
598658 A : Scalar + Lapack ,
599659{
600660 fn factorize_tridiagonal_into ( mut self ) -> Result < LUFactorizedTriDiagonal < A > > {
601- let ( du2, ipiv) = unsafe { A :: lu_tridiagonal ( & mut self ) ? } ;
661+ let ( du2, anom , ipiv) = unsafe { A :: lu_tridiagonal ( & mut self ) ? } ;
602662 Ok ( LUFactorizedTriDiagonal {
603663 a : self ,
604664 du2 : du2,
665+ anom : anom,
605666 ipiv : ipiv,
606667 } )
607668 }
@@ -613,8 +674,8 @@ where
613674{
614675 fn factorize_tridiagonal ( & self ) -> Result < LUFactorizedTriDiagonal < A > > {
615676 let mut a = self . clone ( ) ;
616- let ( du2, ipiv) = unsafe { A :: lu_tridiagonal ( & mut a) ? } ;
617- Ok ( LUFactorizedTriDiagonal { a, du2, ipiv } )
677+ let ( du2, anom , ipiv) = unsafe { A :: lu_tridiagonal ( & mut a) ? } ;
678+ Ok ( LUFactorizedTriDiagonal { a, du2, anom , ipiv } )
618679 }
619680}
620681
@@ -625,8 +686,8 @@ where
625686{
626687 fn factorize_tridiagonal ( & self ) -> Result < LUFactorizedTriDiagonal < A > > {
627688 let mut a = self . to_tridiagonal ( ) ?;
628- let ( du2, ipiv) = unsafe { A :: lu_tridiagonal ( & mut a) ? } ;
629- Ok ( LUFactorizedTriDiagonal { a, du2, ipiv } )
689+ let ( du2, anom , ipiv) = unsafe { A :: lu_tridiagonal ( & mut a) ? } ;
690+ Ok ( LUFactorizedTriDiagonal { a, du2, anom , ipiv } )
630691 }
631692}
632693
0 commit comments