1use crate::*;
2use cauchy::*;
3use num_traits::Zero;
4
5#[derive(Clone, PartialEq)]
7pub struct LUFactorizedTridiagonal<A: Scalar> {
8 pub a: Tridiagonal<A>,
14 pub du2: Vec<A>,
16 pub ipiv: Pivot,
18
19 pub a_opnorm_one: A::Real,
20}
21
22impl<A: Scalar> Tridiagonal<A> {
23 fn opnorm_one(&self) -> A::Real {
24 let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
25 for i in 0..col_sum.len() {
26 if i < self.dl.len() {
27 col_sum[i] += self.dl[i].abs();
28 }
29 if i > 0 {
30 col_sum[i] += self.du[i - 1].abs();
31 }
32 }
33 let mut max = A::Real::zero();
34 for &val in &col_sum {
35 if max < val {
36 max = val;
37 }
38 }
39 max
40 }
41}
42
43pub struct LuTridiagonalWork<T: Scalar> {
44 pub layout: MatrixLayout,
45 pub du2: Vec<MaybeUninit<T>>,
46 pub ipiv: Vec<MaybeUninit<i32>>,
47}
48
49pub trait LuTridiagonalWorkImpl {
50 type Elem: Scalar;
51 fn new(layout: MatrixLayout) -> Self;
52 fn eval(self, a: Tridiagonal<Self::Elem>) -> Result<LUFactorizedTridiagonal<Self::Elem>>;
53}
54
/// Implements `LuTridiagonalWorkImpl` for scalar type `$s` using the LAPACK routine `$trf`
/// (one of the `?gttrf_` bindings).
macro_rules! impl_lu_tridiagonal_work {
    ($s:ty, $trf:path) => {
        impl LuTridiagonalWorkImpl for LuTridiagonalWork<$s> {
            type Elem = $s;

            /// Allocates `max(n - 2, 0)` entries for DU2 and `max(n, 0)` entries for IPIV,
            /// where `n` is the first dimension of `layout`.
            fn new(layout: MatrixLayout) -> Self {
                let (n, _) = layout.size();
                // Clamp before casting: `(n - 2) as usize` would wrap to a huge value for
                // `n < 2`, turning a 1x1 or empty matrix into an allocation failure.
                let du2 = vec_uninit((n.max(2) - 2) as usize);
                let ipiv = vec_uninit(n.max(0) as usize);
                LuTridiagonalWork { layout, du2, ipiv }
            }

            /// Factorizes `a` in place with `$trf` and collects the outputs.
            ///
            /// # Panics
            /// Panics if `a` or the workspace buffers do not match the size `n` taken from
            /// the workspace layout. Without this check, LAPACK would access memory
            /// out of bounds.
            ///
            /// # Errors
            /// Propagates the error produced from LAPACK's `info` code.
            fn eval(
                mut self,
                mut a: Tridiagonal<Self::Elem>,
            ) -> Result<LUFactorizedTridiagonal<Self::Elem>> {
                let (n, _) = self.layout.size();
                let len = n.max(0) as usize;
                let off_diag_len = len.saturating_sub(1);
                let du2_len = (n.max(2) - 2) as usize;

                // LAPACK trusts `n` blindly, so every buffer it touches must be large
                // enough. The workspace buffers must match exactly, because all of their
                // entries are claimed initialized below.
                assert!(
                    a.d.len() >= len && a.dl.len() >= off_diag_len && a.du.len() >= off_diag_len,
                    "tridiagonal matrix is smaller than the layout of the LU workspace"
                );
                assert!(
                    self.du2.len() == du2_len && self.ipiv.len() == len,
                    "LU workspace buffers do not match the workspace layout"
                );

                // Record the norm before `$trf` overwrites `a` with its factors.
                let a_opnorm_one = a.opnorm_one();
                let mut info = 0;
                // SAFETY: the asserts above guarantee `dl`/`du` hold at least `n - 1`,
                // `d` at least `n`, DU2 exactly `n - 2` and IPIV exactly `n` elements, which
                // is what `?gttrf` reads and writes.
                unsafe {
                    $trf(
                        &n,
                        AsPtr::as_mut_ptr(&mut a.dl),
                        AsPtr::as_mut_ptr(&mut a.d),
                        AsPtr::as_mut_ptr(&mut a.du),
                        AsPtr::as_mut_ptr(&mut self.du2),
                        AsPtr::as_mut_ptr(&mut self.ipiv),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(LUFactorizedTridiagonal {
                    a,
                    // SAFETY: `?gttrf` sets every DU2 entry and every IPIV entry, and both
                    // buffers have exactly the lengths LAPACK fills (asserted above).
                    du2: unsafe { self.du2.assume_init() },
                    ipiv: unsafe { self.ipiv.assume_init() },
                    a_opnorm_one,
                })
            }
        }
    };
}
97
98impl_lu_tridiagonal_work!(c64, lapack_sys::zgttrf_);
99impl_lu_tridiagonal_work!(c32, lapack_sys::cgttrf_);
100impl_lu_tridiagonal_work!(f64, lapack_sys::dgttrf_);
101impl_lu_tridiagonal_work!(f32, lapack_sys::sgttrf_);