lax/
solve.rs

1//! Solve linear equations using LU-decomposition
2
3use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
7/// Helper trait to abstract `*getrf` LAPACK routines for implementing [Lapack::lu]
8///
9/// LAPACK correspondance
10/// ----------------------
11///
12/// | f32    | f64    | c32    | c64    |
13/// |:-------|:-------|:-------|:-------|
14/// | sgetrf | dgetrf | cgetrf | zgetrf |
15///
16pub trait LuImpl: Scalar {
17    fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
18}
19
20macro_rules! impl_lu {
21    ($scalar:ty, $getrf:path) => {
22        impl LuImpl for $scalar {
23            fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
24                let (row, col) = l.size();
25                assert_eq!(a.len() as i32, row * col);
26                if row == 0 || col == 0 {
27                    // Do nothing for empty matrix
28                    return Ok(Vec::new());
29                }
30                let k = ::std::cmp::min(row, col);
31                let mut ipiv = vec_uninit(k as usize);
32                let mut info = 0;
33                unsafe {
34                    $getrf(
35                        &l.lda(),
36                        &l.len(),
37                        AsPtr::as_mut_ptr(a),
38                        &l.lda(),
39                        AsPtr::as_mut_ptr(&mut ipiv),
40                        &mut info,
41                    )
42                };
43                info.as_lapack_result()?;
44                let ipiv = unsafe { ipiv.assume_init() };
45                Ok(ipiv)
46            }
47        }
48    };
49}
50
51impl_lu!(c64, lapack_sys::zgetrf_);
52impl_lu!(c32, lapack_sys::cgetrf_);
53impl_lu!(f64, lapack_sys::dgetrf_);
54impl_lu!(f32, lapack_sys::sgetrf_);
55
56#[cfg_attr(doc, katexit::katexit)]
57/// Helper trait to abstract `*getrs` LAPACK routines for implementing [Lapack::solve]
58///
59/// If the array has C layout, then it needs to be handled
60/// specially, since LAPACK expects a Fortran-layout array.
61/// Reinterpreting a C layout array as Fortran layout is
62/// equivalent to transposing it. So, we can handle the "no
63/// transpose" and "transpose" cases by swapping to "transpose"
64/// or "no transpose", respectively. For the "Hermite" case, we
65/// can take advantage of the following:
66///
67/// $$
68/// \begin{align*}
69///   A^H x &= b \\\\
70///   \Leftrightarrow \overline{A^T} x &= b \\\\
71///   \Leftrightarrow \overline{\overline{A^T} x} &= \overline{b} \\\\
72///   \Leftrightarrow \overline{\overline{A^T}} \overline{x} &= \overline{b} \\\\
73///   \Leftrightarrow A^T \overline{x} &= \overline{b}
74/// \end{align*}
75/// $$
76///
77/// So, we can handle this case by switching to "no transpose"
78/// (which is equivalent to transposing the array since it will
79/// be reinterpreted as Fortran layout) and applying the
80/// elementwise conjugate to `x` and `b`.
81///
82pub trait SolveImpl: Scalar {
83    /// LAPACK correspondance
84    /// ----------------------
85    ///
86    /// | f32    | f64    | c32    | c64    |
87    /// |:-------|:-------|:-------|:-------|
88    /// | sgetrs | dgetrs | cgetrs | zgetrs |
89    ///
90    fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
91}
92
93macro_rules! impl_solve {
94    ($scalar:ty, $getrs:path) => {
95        impl SolveImpl for $scalar {
96            fn solve(
97                l: MatrixLayout,
98                t: Transpose,
99                a: &[Self],
100                ipiv: &Pivot,
101                b: &mut [Self],
102            ) -> Result<()> {
103                let (t, conj) = match l {
104                    MatrixLayout::C { .. } => match t {
105                        Transpose::No => (Transpose::Transpose, false),
106                        Transpose::Transpose => (Transpose::No, false),
107                        Transpose::Hermite => (Transpose::No, true),
108                    },
109                    MatrixLayout::F { .. } => (t, false),
110                };
111                let (n, _) = l.size();
112                let nrhs = 1;
113                let ldb = l.lda();
114                let mut info = 0;
115                if conj {
116                    for b_elem in &mut *b {
117                        *b_elem = b_elem.conj();
118                    }
119                }
120                unsafe {
121                    $getrs(
122                        t.as_ptr(),
123                        &n,
124                        &nrhs,
125                        AsPtr::as_ptr(a),
126                        &l.lda(),
127                        ipiv.as_ptr(),
128                        AsPtr::as_mut_ptr(b),
129                        &ldb,
130                        &mut info,
131                    )
132                };
133                if conj {
134                    for b_elem in &mut *b {
135                        *b_elem = b_elem.conj();
136                    }
137                }
138                info.as_lapack_result()?;
139                Ok(())
140            }
141        }
142    };
143} // impl_solve!
144
145impl_solve!(f64, lapack_sys::dgetrs_);
146impl_solve!(f32, lapack_sys::sgetrs_);
147impl_solve!(c64, lapack_sys::zgetrs_);
148impl_solve!(c32, lapack_sys::cgetrs_);
149
150/// Working memory for computing inverse matrix
151pub struct InvWork<T: Scalar> {
152    pub layout: MatrixLayout,
153    pub work: Vec<MaybeUninit<T>>,
154}
155
156/// Helper trait to abstract `*getri` LAPACK rotuines for implementing [Lapack::inv]
157///
158/// LAPACK correspondance
159/// ----------------------
160///
161/// | f32    | f64    | c32    | c64    |
162/// |:-------|:-------|:-------|:-------|
163/// | sgetri | dgetri | cgetri | zgetri |
164///
165pub trait InvWorkImpl: Sized {
166    type Elem: Scalar;
167    fn new(layout: MatrixLayout) -> Result<Self>;
168    fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>;
169}
170
171macro_rules! impl_inv_work {
172    ($s:ty, $tri:path) => {
173        impl InvWorkImpl for InvWork<$s> {
174            type Elem = $s;
175
176            fn new(layout: MatrixLayout) -> Result<Self> {
177                let (n, _) = layout.size();
178                let mut info = 0;
179                let mut work_size = [Self::Elem::zero()];
180                unsafe {
181                    $tri(
182                        &n,
183                        std::ptr::null_mut(),
184                        &layout.lda(),
185                        std::ptr::null(),
186                        AsPtr::as_mut_ptr(&mut work_size),
187                        &(-1),
188                        &mut info,
189                    )
190                };
191                info.as_lapack_result()?;
192                let lwork = work_size[0].to_usize().unwrap();
193                let work = vec_uninit(lwork);
194                Ok(InvWork { layout, work })
195            }
196
197            fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
198                if self.layout.len() == 0 {
199                    return Ok(());
200                }
201                let lwork = self.work.len().to_i32().unwrap();
202                let mut info = 0;
203                unsafe {
204                    $tri(
205                        &self.layout.len(),
206                        AsPtr::as_mut_ptr(a),
207                        &self.layout.lda(),
208                        ipiv.as_ptr(),
209                        AsPtr::as_mut_ptr(&mut self.work),
210                        &lwork,
211                        &mut info,
212                    )
213                };
214                info.as_lapack_result()?;
215                Ok(())
216            }
217        }
218    };
219}
220
221impl_inv_work!(c64, lapack_sys::zgetri_);
222impl_inv_work!(c32, lapack_sys::cgetri_);
223impl_inv_work!(f64, lapack_sys::dgetri_);
224impl_inv_work!(f32, lapack_sys::sgetri_);