1use crate::{error::*, layout::*, *};
2use cauchy::*;
3
4pub trait SolveTridiagonalImpl: Scalar {
5 fn solve_tridiagonal(
6 lu: &LUFactorizedTridiagonal<Self>,
7 bl: MatrixLayout,
8 t: Transpose,
9 b: &mut [Self],
10 ) -> Result<()>;
11}
12
/// Generates a [`SolveTridiagonalImpl`] implementation for the scalar type `$s`.
///
/// `$trs` is the LAPACK `?gttrs` binding for that scalar type.
macro_rules! impl_solve_tridiagonal {
    ($s:ty, $trs:path) => {
        impl SolveTridiagonalImpl for $s {
            fn solve_tridiagonal(
                lu: &LUFactorizedTridiagonal<Self>,
                b_layout: MatrixLayout,
                t: Transpose,
                b: &mut [Self],
            ) -> Result<()> {
                // LAPACK works on column-major data. A column-major `b` is handed over as-is.
                // A row-major `b` is first copied into a column-major scratch buffer, and
                // the layout is switched to describe that buffer.
                let (b_layout, mut scratch) = match b_layout {
                    MatrixLayout::F { .. } => (b_layout, None),
                    MatrixLayout::C { .. } => {
                        let (col_major_layout, buf) = transpose(b_layout, b);
                        (col_major_layout, Some(buf))
                    }
                };
                let (ldb, nrhs) = b_layout.size();
                let (n, _) = lu.a.l.size();

                // The buffer LAPACK reads the RHS from and overwrites with the solution.
                let rhs: &mut [Self] = match scratch.as_mut() {
                    Some(buf) => buf.as_mut_slice(),
                    None => &mut *b,
                };

                let mut info = 0;
                // SAFETY: every pointer is derived from a borrow (`lu`'s vectors, `rhs`) or a
                // local (`n`, `nrhs`, `ldb`, `info`) that stays alive for the whole call.
                // NOTE(review): soundness also relies on `b_layout` matching `b`'s length and
                // on `lu`'s vectors having the lengths `?gttrs` expects; neither is checked here.
                unsafe {
                    $trs(
                        t.as_ptr(),
                        &n,
                        &nrhs,
                        AsPtr::as_ptr(&lu.a.dl),
                        AsPtr::as_ptr(&lu.a.d),
                        AsPtr::as_ptr(&lu.a.du),
                        AsPtr::as_ptr(&lu.du2),
                        lu.ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(rhs),
                        &ldb,
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // When a scratch buffer was used, copy its column-major solution back into
                // the caller's row-major `b`.
                if let Some(buf) = scratch {
                    transpose_over(b_layout, &buf, b);
                }
                Ok(())
            }
        }
    };
}
60
61impl_solve_tridiagonal!(c64, lapack_sys::zgttrs_);
62impl_solve_tridiagonal!(c32, lapack_sys::cgttrs_);
63impl_solve_tridiagonal!(f64, lapack_sys::dgttrs_);
64impl_solve_tridiagonal!(f32, lapack_sys::sgttrs_);