lax/
triangular.rs

1//! Linear problem for triangular matrices
2
3use crate::{error::*, layout::*, *};
4use cauchy::*;
5
6/// Solve linear problem for triangular matrices
7///
8/// LAPACK correspondance
9/// ----------------------
10///
11/// | f32    | f64    | c32    | c64    |
12/// |:-------|:-------|:-------|:-------|
13/// | strtrs | dtrtrs | ctrtrs | ztrtrs |
14///
15pub trait SolveTriangularImpl: Scalar {
16    fn solve_triangular(
17        al: MatrixLayout,
18        bl: MatrixLayout,
19        uplo: UPLO,
20        d: Diag,
21        a: &[Self],
22        b: &mut [Self],
23    ) -> Result<()>;
24}
25
// Implements `SolveTriangularImpl` for `$scalar` by calling the LAPACK
// `?trtrs` routine `$trtrs`. Inputs in C (row-major) layout are copied into
// F (column-major) buffers before the call. The solution for `b` is copied
// back into the caller's C-layout buffer afterwards.
macro_rules! impl_triangular {
    ($scalar:ty, $trtrs:path) => {
        impl SolveTriangularImpl for $scalar {
            fn solve_triangular(
                a_layout: MatrixLayout,
                b_layout: MatrixLayout,
                uplo: UPLO,
                diag: Diag,
                a: &[Self],
                b: &mut [Self],
            ) -> Result<()> {
                // LAPACK is called with column-major storage only, so a
                // C-layout `a` is copied into an F-layout buffer. `uplo` is
                // passed through unchanged, which assumes `transpose` keeps
                // the same logical matrix and only changes the storage order.
                let mut a_t = None;
                let a_layout = match a_layout {
                    MatrixLayout::C { .. } => {
                        let (layout, t) = transpose(a_layout, a);
                        a_t = Some(t);
                        layout
                    }
                    MatrixLayout::F { .. } => a_layout,
                };

                // Same for the right-hand side. The LAPACK result lands in
                // `b_t` and is copied back into `b` after a successful call.
                let mut b_t = None;
                let b_layout = match b_layout {
                    MatrixLayout::C { .. } => {
                        let (layout, t) = transpose(b_layout, b);
                        b_t = Some(t);
                        layout
                    }
                    MatrixLayout::F { .. } => b_layout,
                };

                let (m, n) = a_layout.size();
                let (n_, nrhs) = b_layout.size();
                // `m` is handed to LAPACK as the order N of A. A non-square A
                // would be reinterpreted as an m x m matrix that does not
                // match the caller's data, so reject it up front.
                assert_eq!(
                    m, n,
                    "triangular matrix A must be square (got {} x {})",
                    m, n
                );
                assert_eq!(
                    n, n_,
                    "columns of A ({}) must equal rows of B ({})",
                    n, n_
                );

                let mut info = 0;
                // SAFETY: every pointer is derived from a slice (`a`/`a_t`,
                // `b`/`b_t`) that stays alive for the whole call. The orders
                // and leading dimensions are read from the layouts describing
                // those buffers. This assumes the caller's layouts match the
                // slice lengths, which is not re-checked here.
                unsafe {
                    $trtrs(
                        uplo.as_ptr(),
                        Transpose::No.as_ptr(),
                        diag.as_ptr(),
                        &m,
                        &nrhs,
                        AsPtr::as_ptr(a_t.as_ref().map(|v| v.as_slice()).unwrap_or(a)),
                        &a_layout.lda(),
                        AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
                        &b_layout.lda(),
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                // Copy the F-layout solution back into the caller's C-layout `b`.
                if let Some(b_t) = b_t {
                    transpose_over(b_layout, &b_t, b);
                }
                Ok(())
            }
        }
    };
} // impl_triangular!
89
90impl_triangular!(f64, lapack_sys::dtrtrs_);
91impl_triangular!(f32, lapack_sys::strtrs_);
92impl_triangular!(c64, lapack_sys::ztrtrs_);
93impl_triangular!(c32, lapack_sys::ctrtrs_);