lax/tridiagonal/
solve.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
use crate::{error::*, layout::*, *};
use cauchy::*;

pub trait SolveTridiagonalImpl: Scalar {
    fn solve_tridiagonal(
        lu: &LUFactorizedTridiagonal<Self>,
        bl: MatrixLayout,
        t: Transpose,
        b: &mut [Self],
    ) -> Result<()>;
}

macro_rules! impl_solve_tridiagonal {
    ($s:ty, $trs:path) => {
        impl SolveTridiagonalImpl for $s {
            fn solve_tridiagonal(
                lu: &LUFactorizedTridiagonal<Self>,
                b_layout: MatrixLayout,
                t: Transpose,
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = lu.a.l.size();
                let ipiv = &lu.ipiv;
                // Transpose if b is C-continuous
                let mut b_t = None;
                let b_layout = match b_layout {
                    MatrixLayout::C { .. } => {
                        let (layout, t) = transpose(b_layout, b);
                        b_t = Some(t);
                        layout
                    }
                    MatrixLayout::F { .. } => b_layout,
                };
                let (ldb, nrhs) = b_layout.size();
                let mut info = 0;
                unsafe {
                    $trs(
                        t.as_ptr(),
                        &n,
                        &nrhs,
                        AsPtr::as_ptr(&lu.a.dl),
                        AsPtr::as_ptr(&lu.a.d),
                        AsPtr::as_ptr(&lu.a.du),
                        AsPtr::as_ptr(&lu.du2),
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
                        &ldb,
                        &mut info,
                    );
                }
                info.as_lapack_result()?;
                if let Some(b_t) = b_t {
                    transpose_over(b_layout, &b_t, b);
                }
                Ok(())
            }
        }
    };
}

impl_solve_tridiagonal!(c64, lapack_sys::zgttrs_);
impl_solve_tridiagonal!(c32, lapack_sys::cgttrs_);
impl_solve_tridiagonal!(f64, lapack_sys::dgttrs_);
impl_solve_tridiagonal!(f32, lapack_sys::sgttrs_);