lax/tridiagonal/solve.rs

1use crate::{error::*, layout::*, *};
2use cauchy::*;
3
4pub trait SolveTridiagonalImpl: Scalar {
5    fn solve_tridiagonal(
6        lu: &LUFactorizedTridiagonal<Self>,
7        bl: MatrixLayout,
8        t: Transpose,
9        b: &mut [Self],
10    ) -> Result<()>;
11}
12
/// Implement [`SolveTridiagonalImpl`] for `$scalar` on top of the LAPACK routine
/// `$gttrs` (one of `?gttrs`), which solves with an LU-factorized tridiagonal matrix.
macro_rules! impl_solve_tridiagonal {
    ($scalar:ty, $gttrs:path) => {
        impl SolveTridiagonalImpl for $scalar {
            fn solve_tridiagonal(
                lu: &LUFactorizedTridiagonal<Self>,
                b_layout: MatrixLayout,
                t: Transpose,
                b: &mut [Self],
            ) -> Result<()> {
                // LAPACK works on column-major data. A row-major (C-contiguous) `b` is
                // first copied into a column-major scratch buffer; an F-layout `b` is
                // handed to LAPACK in place.
                let (f_layout, mut scratch) = match b_layout {
                    MatrixLayout::F { .. } => (b_layout, None),
                    MatrixLayout::C { .. } => {
                        let (col_major_layout, col_major_copy) = transpose(b_layout, b);
                        (col_major_layout, Some(col_major_copy))
                    }
                };

                // Order of the system, taken from the stored layout of `A`
                // (NOTE(review): assumes `A` is square, so the row count is `n`).
                let (n, _) = lu.a.l.size();
                let (ldb, nrhs) = f_layout.size();

                // The buffer LAPACK overwrites with the solution.
                let rhs: &mut [Self] = match scratch.as_mut() {
                    Some(col_major_copy) => col_major_copy.as_mut_slice(),
                    None => &mut *b,
                };

                let mut info = 0;
                // SAFETY: every pointer comes from a live borrow that outlives the call:
                // the diagonals, `du2` and `ipiv` are shared borrows of `lu`, `rhs` is an
                // exclusive borrow, and the scalar arguments are locals.
                // NOTE(review): the slice lengths are not checked against `n`, `ldb` and
                // `nrhs` here; this relies on `lu` and `b_layout` being consistent with the
                // buffers.
                unsafe {
                    $gttrs(
                        t.as_ptr(),
                        &n,
                        &nrhs,
                        AsPtr::as_ptr(&lu.a.dl),
                        AsPtr::as_ptr(&lu.a.d),
                        AsPtr::as_ptr(&lu.a.du),
                        AsPtr::as_ptr(&lu.du2),
                        lu.ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(rhs),
                        &ldb,
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // If a scratch copy was used, write the column-major solution back into
                // the caller's row-major buffer.
                if let Some(col_major_copy) = scratch {
                    transpose_over(f_layout, &col_major_copy, b);
                }
                Ok(())
            }
        }
    };
}
60
61impl_solve_tridiagonal!(c64, lapack_sys::zgttrs_);
62impl_solve_tridiagonal!(c32, lapack_sys::cgttrs_);
63impl_solve_tridiagonal!(f64, lapack_sys::dgttrs_);
64impl_solve_tridiagonal!(f32, lapack_sys::sgttrs_);