lax/solve.rs
1//! Solve linear equations using LU-decomposition
2
3use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
/// Helper trait to abstract `*getrf` LAPACK routines for implementing [Lapack::lu]
///
/// LAPACK correspondence
/// ----------------------
///
/// | f32    | f64    | c32    | c64    |
/// |:-------|:-------|:-------|:-------|
/// | sgetrf | dgetrf | cgetrf | zgetrf |
///
pub trait LuImpl: Scalar {
    /// Computes the LU factorization of the matrix stored in `a`.
    ///
    /// `l` gives the shape and memory layout of `a`. The buffer is handed to
    /// `*getrf` mutably (per the LAPACK contract it is overwritten with the
    /// factors) and the pivot indices written by the routine are returned.
    ///
    /// The implementations in this file return an empty pivot without calling
    /// LAPACK when either dimension is zero, panic when `a.len()` does not
    /// match the size reported by `l`, and return an error when LAPACK reports
    /// a non-zero `info`.
    fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
}
19
20macro_rules! impl_lu {
21 ($scalar:ty, $getrf:path) => {
22 impl LuImpl for $scalar {
23 fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
24 let (row, col) = l.size();
25 assert_eq!(a.len() as i32, row * col);
26 if row == 0 || col == 0 {
27 // Do nothing for empty matrix
28 return Ok(Vec::new());
29 }
30 let k = ::std::cmp::min(row, col);
31 let mut ipiv = vec_uninit(k as usize);
32 let mut info = 0;
33 unsafe {
34 $getrf(
35 &l.lda(),
36 &l.len(),
37 AsPtr::as_mut_ptr(a),
38 &l.lda(),
39 AsPtr::as_mut_ptr(&mut ipiv),
40 &mut info,
41 )
42 };
43 info.as_lapack_result()?;
44 let ipiv = unsafe { ipiv.assume_init() };
45 Ok(ipiv)
46 }
47 }
48 };
49}
50
51impl_lu!(c64, lapack_sys::zgetrf_);
52impl_lu!(c32, lapack_sys::cgetrf_);
53impl_lu!(f64, lapack_sys::dgetrf_);
54impl_lu!(f32, lapack_sys::sgetrf_);
55
#[cfg_attr(doc, katexit::katexit)]
/// Helper trait to abstract `*getrs` LAPACK routines for implementing [Lapack::solve]
///
/// If the array has C layout, then it needs to be handled
/// specially, since LAPACK expects a Fortran-layout array.
/// Reinterpreting a C layout array as Fortran layout is
/// equivalent to transposing it. So, we can handle the "no
/// transpose" and "transpose" cases by swapping to "transpose"
/// or "no transpose", respectively. For the "Hermite" case, we
/// can take advantage of the following:
///
/// $$
/// \begin{align*}
///   A^H x &= b \\\\
///   \Leftrightarrow \overline{A^T} x &= b \\\\
///   \Leftrightarrow \overline{\overline{A^T} x} &= \overline{b} \\\\
///   \Leftrightarrow \overline{\overline{A^T}} \overline{x} &= \overline{b} \\\\
///   \Leftrightarrow A^T \overline{x} &= \overline{b}
/// \end{align*}
/// $$
///
/// So, we can handle this case by switching to "no transpose"
/// (which is equivalent to transposing the array since it will
/// be reinterpreted as Fortran layout) and applying the
/// elementwise conjugate to `x` and `b`.
///
pub trait SolveImpl: Scalar {
    /// Solves a linear system for a single right-hand side stored in `b`.
    ///
    /// `a` and `p` are passed to `*getrs` as the factorized matrix and its
    /// pivot indices (presumably the output of [LuImpl::lu] for the same
    /// layout `l`), and `t` selects which of the systems with `A`, `A^T` or
    /// `A^H` is solved. `b` is handed to LAPACK mutably; per the `*getrs`
    /// contract it is overwritten with the solution. An error is returned
    /// when LAPACK reports a non-zero `info`.
    ///
    /// LAPACK correspondence
    /// ----------------------
    ///
    /// | f32    | f64    | c32    | c64    |
    /// |:-------|:-------|:-------|:-------|
    /// | sgetrs | dgetrs | cgetrs | zgetrs |
    ///
    fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
}
92
93macro_rules! impl_solve {
94 ($scalar:ty, $getrs:path) => {
95 impl SolveImpl for $scalar {
96 fn solve(
97 l: MatrixLayout,
98 t: Transpose,
99 a: &[Self],
100 ipiv: &Pivot,
101 b: &mut [Self],
102 ) -> Result<()> {
103 let (t, conj) = match l {
104 MatrixLayout::C { .. } => match t {
105 Transpose::No => (Transpose::Transpose, false),
106 Transpose::Transpose => (Transpose::No, false),
107 Transpose::Hermite => (Transpose::No, true),
108 },
109 MatrixLayout::F { .. } => (t, false),
110 };
111 let (n, _) = l.size();
112 let nrhs = 1;
113 let ldb = l.lda();
114 let mut info = 0;
115 if conj {
116 for b_elem in &mut *b {
117 *b_elem = b_elem.conj();
118 }
119 }
120 unsafe {
121 $getrs(
122 t.as_ptr(),
123 &n,
124 &nrhs,
125 AsPtr::as_ptr(a),
126 &l.lda(),
127 ipiv.as_ptr(),
128 AsPtr::as_mut_ptr(b),
129 &ldb,
130 &mut info,
131 )
132 };
133 if conj {
134 for b_elem in &mut *b {
135 *b_elem = b_elem.conj();
136 }
137 }
138 info.as_lapack_result()?;
139 Ok(())
140 }
141 }
142 };
143} // impl_solve!
144
145impl_solve!(f64, lapack_sys::dgetrs_);
146impl_solve!(f32, lapack_sys::sgetrs_);
147impl_solve!(c64, lapack_sys::zgetrs_);
148impl_solve!(c32, lapack_sys::cgetrs_);
149
/// Working memory for computing inverse matrix
///
/// Created by [InvWorkImpl::new] for a given layout and consumed by
/// [InvWorkImpl::calc].
pub struct InvWork<T: Scalar> {
    /// Layout of the matrix this working memory was created for.
    pub layout: MatrixLayout,
    /// Scratch buffer handed to `*getri` as its work array; its length is the
    /// size reported by the LAPACK workspace query performed in `new`.
    pub work: Vec<MaybeUninit<T>>,
}
155
/// Helper trait to abstract `*getri` LAPACK routines for implementing [Lapack::inv]
///
/// LAPACK correspondence
/// ----------------------
///
/// | f32    | f64    | c32    | c64    |
/// |:-------|:-------|:-------|:-------|
/// | sgetri | dgetri | cgetri | zgetri |
///
pub trait InvWorkImpl: Sized {
    /// Scalar type of the matrix elements.
    type Elem: Scalar;
    /// Allocates working memory for a matrix with the given `layout`.
    ///
    /// The buffer size is obtained from a `*getri` workspace query
    /// (`lwork = -1`); an error is returned when that query reports a
    /// non-zero `info`.
    fn new(layout: MatrixLayout) -> Result<Self>;
    /// Runs `*getri` on `a` using the pivot indices `p` and the stored
    /// working memory.
    ///
    /// `a` and `p` are presumably the output of the LU factorization for the
    /// same layout; per the `*getri` contract `a` is overwritten with the
    /// inverse. Nothing is done when the stored layout is empty, and an error
    /// is returned when LAPACK reports a non-zero `info`.
    fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>;
}
170
171macro_rules! impl_inv_work {
172 ($s:ty, $tri:path) => {
173 impl InvWorkImpl for InvWork<$s> {
174 type Elem = $s;
175
176 fn new(layout: MatrixLayout) -> Result<Self> {
177 let (n, _) = layout.size();
178 let mut info = 0;
179 let mut work_size = [Self::Elem::zero()];
180 unsafe {
181 $tri(
182 &n,
183 std::ptr::null_mut(),
184 &layout.lda(),
185 std::ptr::null(),
186 AsPtr::as_mut_ptr(&mut work_size),
187 &(-1),
188 &mut info,
189 )
190 };
191 info.as_lapack_result()?;
192 let lwork = work_size[0].to_usize().unwrap();
193 let work = vec_uninit(lwork);
194 Ok(InvWork { layout, work })
195 }
196
197 fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
198 if self.layout.len() == 0 {
199 return Ok(());
200 }
201 let lwork = self.work.len().to_i32().unwrap();
202 let mut info = 0;
203 unsafe {
204 $tri(
205 &self.layout.len(),
206 AsPtr::as_mut_ptr(a),
207 &self.layout.lda(),
208 ipiv.as_ptr(),
209 AsPtr::as_mut_ptr(&mut self.work),
210 &lwork,
211 &mut info,
212 )
213 };
214 info.as_lapack_result()?;
215 Ok(())
216 }
217 }
218 };
219}
220
221impl_inv_work!(c64, lapack_sys::zgetri_);
222impl_inv_work!(c32, lapack_sys::cgetri_);
223impl_inv_work!(f64, lapack_sys::dgetri_);
224impl_inv_work!(f32, lapack_sys::sgetri_);