lax/tridiagonal/
lu.rs

1use crate::*;
2use cauchy::*;
3use num_traits::Zero;
4
5/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
6#[derive(Clone, PartialEq)]
7pub struct LUFactorizedTridiagonal<A: Scalar> {
8    /// A tridiagonal matrix which consists of
9    /// - l : layout of raw matrix
10    /// - dl: (n-1) multipliers that define the matrix L.
11    /// - d : (n) diagonal elements of the upper triangular matrix U.
12    /// - du: (n-1) elements of the first super-diagonal of U.
13    pub a: Tridiagonal<A>,
14    /// (n-2) elements of the second super-diagonal of U.
15    pub du2: Vec<A>,
16    /// The pivot indices that define the permutation matrix `P`.
17    pub ipiv: Pivot,
18
19    pub a_opnorm_one: A::Real,
20}
21
22impl<A: Scalar> Tridiagonal<A> {
23    fn opnorm_one(&self) -> A::Real {
24        let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
25        for i in 0..col_sum.len() {
26            if i < self.dl.len() {
27                col_sum[i] += self.dl[i].abs();
28            }
29            if i > 0 {
30                col_sum[i] += self.du[i - 1].abs();
31            }
32        }
33        let mut max = A::Real::zero();
34        for &val in &col_sum {
35            if max < val {
36                max = val;
37            }
38        }
39        max
40    }
41}
42
43pub struct LuTridiagonalWork<T: Scalar> {
44    pub layout: MatrixLayout,
45    pub du2: Vec<MaybeUninit<T>>,
46    pub ipiv: Vec<MaybeUninit<i32>>,
47}
48
49pub trait LuTridiagonalWorkImpl {
50    type Elem: Scalar;
51    fn new(layout: MatrixLayout) -> Self;
52    fn eval(self, a: Tridiagonal<Self::Elem>) -> Result<LUFactorizedTridiagonal<Self::Elem>>;
53}
54
/// Implements [`LuTridiagonalWorkImpl`] for the scalar type `$s` by calling the LAPACK
/// tridiagonal LU routine `$trf` (`?gttrf`).
macro_rules! impl_lu_tridiagonal_work {
    ($s:ty, $trf:path) => {
        impl LuTridiagonalWorkImpl for LuTridiagonalWork<$s> {
            type Elem = $s;

            fn new(layout: MatrixLayout) -> Self {
                let (n, _) = layout.size();
                // `du2` holds `n - 2` entries. Clamp at zero so that `n = 0` or `n = 1`
                // gives an empty buffer; otherwise `-1`/`-2` would wrap to a huge
                // `usize`, and `vec_uninit` would panic on capacity overflow.
                let du2 = vec_uninit(n.saturating_sub(2).max(0) as usize);
                let ipiv = vec_uninit(n.max(0) as usize);
                LuTridiagonalWork { layout, du2, ipiv }
            }

            fn eval(
                mut self,
                mut a: Tridiagonal<Self::Elem>,
            ) -> Result<LUFactorizedTridiagonal<Self::Elem>> {
                let (n, _) = self.layout.size();
                let n_len = n.max(0) as usize;
                let off_len = n_len.saturating_sub(1);
                // The LAPACK call below writes through raw pointers, so the buffers must
                // be long enough for order `n` before they are handed over.
                assert!(
                    a.d.len() >= n_len && a.dl.len() >= off_len && a.du.len() >= off_len,
                    "tridiagonal buffers are too short for a matrix of order {}",
                    n
                );
                // The work buffers must match exactly: any element LAPACK does not write
                // would remain uninitialized after `assume_init` below.
                assert!(
                    self.du2.len() == n_len.saturating_sub(2) && self.ipiv.len() == n_len,
                    "LU workspace buffers do not match a matrix of order {}",
                    n
                );
                // We have to calc one-norm before LU factorization, which overwrites `a`.
                let a_opnorm_one = a.opnorm_one();
                let mut info = 0;
                // SAFETY: `d`/`ipiv` hold at least `n` elements, `dl`/`du` at least
                // `n - 1`, and `du2` exactly `n - 2` (checked above). Each pointer comes
                // from a distinct, exclusively borrowed buffer.
                unsafe {
                    $trf(
                        &n,
                        AsPtr::as_mut_ptr(&mut a.dl),
                        AsPtr::as_mut_ptr(&mut a.d),
                        AsPtr::as_mut_ptr(&mut a.du),
                        AsPtr::as_mut_ptr(&mut self.du2),
                        AsPtr::as_mut_ptr(&mut self.ipiv),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(LUFactorizedTridiagonal {
                    a,
                    // SAFETY: on successful return `?gttrf` has written every element of
                    // `du2` (length n-2) and `ipiv` (length n).
                    du2: unsafe { self.du2.assume_init() },
                    ipiv: unsafe { self.ipiv.assume_init() },
                    a_opnorm_one,
                })
            }
        }
    };
}
97
98impl_lu_tridiagonal_work!(c64, lapack_sys::zgttrf_);
99impl_lu_tridiagonal_work!(c32, lapack_sys::cgttrf_);
100impl_lu_tridiagonal_work!(f64, lapack_sys::dgttrf_);
101impl_lu_tridiagonal_work!(f32, lapack_sys::sgttrf_);