lax/
least_squares.rs

1//! Least squares
2
3use crate::{error::*, layout::*, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
7/// Result of LeastSquares
8pub struct LeastSquaresOwned<A: Scalar> {
9    /// singular values
10    pub singular_values: Vec<A::Real>,
11    /// The rank of the input matrix A
12    pub rank: i32,
13}
14
15/// Result of LeastSquares
16pub struct LeastSquaresRef<'work, A: Scalar> {
17    /// singular values
18    pub singular_values: &'work [A::Real],
19    /// The rank of the input matrix A
20    pub rank: i32,
21}
22
23pub struct LeastSquaresWork<T: Scalar> {
24    pub a_layout: MatrixLayout,
25    pub b_layout: MatrixLayout,
26    pub singular_values: Vec<MaybeUninit<T::Real>>,
27    pub work: Vec<MaybeUninit<T>>,
28    pub iwork: Vec<MaybeUninit<i32>>,
29    pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
30}
31
32pub trait LeastSquaresWorkImpl: Sized {
33    type Elem: Scalar;
34    fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result<Self>;
35    fn calc(
36        &mut self,
37        a: &mut [Self::Elem],
38        b: &mut [Self::Elem],
39    ) -> Result<LeastSquaresRef<Self::Elem>>;
40    fn eval(
41        self,
42        a: &mut [Self::Elem],
43        b: &mut [Self::Elem],
44    ) -> Result<LeastSquaresOwned<Self::Elem>>;
45}
46
47macro_rules! impl_least_squares_work_c {
48    ($c:ty, $lsd:path) => {
49        impl LeastSquaresWorkImpl for LeastSquaresWork<$c> {
50            type Elem = $c;
51
52            fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result<Self> {
53                let (m, n) = a_layout.size();
54                let (m_, nrhs) = b_layout.size();
55                let k = m.min(n);
56                assert!(m_ >= m);
57
58                let rcond = -1.;
59                let mut singular_values = vec_uninit(k as usize);
60                let mut rank: i32 = 0;
61
62                // eval work size
63                let mut info = 0;
64                let mut work_size = [Self::Elem::zero()];
65                let mut iwork_size = [0];
66                let mut rwork = [<Self::Elem as Scalar>::Real::zero()];
67                unsafe {
68                    $lsd(
69                        &m,
70                        &n,
71                        &nrhs,
72                        std::ptr::null_mut(),
73                        &m,
74                        std::ptr::null_mut(),
75                        &m_,
76                        AsPtr::as_mut_ptr(&mut singular_values),
77                        &rcond,
78                        &mut rank,
79                        AsPtr::as_mut_ptr(&mut work_size),
80                        &(-1),
81                        AsPtr::as_mut_ptr(&mut rwork),
82                        iwork_size.as_mut_ptr(),
83                        &mut info,
84                    )
85                };
86                info.as_lapack_result()?;
87
88                let lwork = work_size[0].to_usize().unwrap();
89                let liwork = iwork_size[0].to_usize().unwrap();
90                let lrwork = rwork[0].to_usize().unwrap();
91
92                let work = vec_uninit(lwork);
93                let iwork = vec_uninit(liwork);
94                let rwork = vec_uninit(lrwork);
95
96                Ok(LeastSquaresWork {
97                    a_layout,
98                    b_layout,
99                    work,
100                    iwork,
101                    rwork: Some(rwork),
102                    singular_values,
103                })
104            }
105
106            fn calc(
107                &mut self,
108                a: &mut [Self::Elem],
109                b: &mut [Self::Elem],
110            ) -> Result<LeastSquaresRef<Self::Elem>> {
111                let (m, n) = self.a_layout.size();
112                let (m_, nrhs) = self.b_layout.size();
113                assert!(m_ >= m);
114
115                let lwork = self.work.len().to_i32().unwrap();
116
117                // Transpose if a is C-continuous
118                let mut a_t = None;
119                let _ = match self.a_layout {
120                    MatrixLayout::C { .. } => {
121                        let (layout, t) = transpose(self.a_layout, a);
122                        a_t = Some(t);
123                        layout
124                    }
125                    MatrixLayout::F { .. } => self.a_layout,
126                };
127
128                // Transpose if b is C-continuous
129                let mut b_t = None;
130                let b_layout = match self.b_layout {
131                    MatrixLayout::C { .. } => {
132                        let (layout, t) = transpose(self.b_layout, b);
133                        b_t = Some(t);
134                        layout
135                    }
136                    MatrixLayout::F { .. } => self.b_layout,
137                };
138
139                let rcond: <Self::Elem as Scalar>::Real = -1.;
140                let mut rank: i32 = 0;
141
142                let mut info = 0;
143                unsafe {
144                    $lsd(
145                        &m,
146                        &n,
147                        &nrhs,
148                        AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)),
149                        &m,
150                        AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
151                        &m_,
152                        AsPtr::as_mut_ptr(&mut self.singular_values),
153                        &rcond,
154                        &mut rank,
155                        AsPtr::as_mut_ptr(&mut self.work),
156                        &lwork,
157                        AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
158                        AsPtr::as_mut_ptr(&mut self.iwork),
159                        &mut info,
160                    );
161                }
162                info.as_lapack_result()?;
163
164                let singular_values = unsafe { self.singular_values.slice_assume_init_ref() };
165
166                // Skip a_t -> a transpose because A has been destroyed
167                // Re-transpose b
168                if let Some(b_t) = b_t {
169                    transpose_over(b_layout, &b_t, b);
170                }
171
172                Ok(LeastSquaresRef {
173                    singular_values,
174                    rank,
175                })
176            }
177
178            fn eval(
179                mut self,
180                a: &mut [Self::Elem],
181                b: &mut [Self::Elem],
182            ) -> Result<LeastSquaresOwned<Self::Elem>> {
183                let LeastSquaresRef { rank, .. } = self.calc(a, b)?;
184                let singular_values = unsafe { self.singular_values.assume_init() };
185                Ok(LeastSquaresOwned {
186                    singular_values,
187                    rank,
188                })
189            }
190        }
191    };
192}
193impl_least_squares_work_c!(c64, lapack_sys::zgelsd_);
194impl_least_squares_work_c!(c32, lapack_sys::cgelsd_);
195
196macro_rules! impl_least_squares_work_r {
197    ($c:ty, $lsd:path) => {
198        impl LeastSquaresWorkImpl for LeastSquaresWork<$c> {
199            type Elem = $c;
200
201            fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result<Self> {
202                let (m, n) = a_layout.size();
203                let (m_, nrhs) = b_layout.size();
204                let k = m.min(n);
205                assert!(m_ >= m);
206
207                let rcond = -1.;
208                let mut singular_values = vec_uninit(k as usize);
209                let mut rank: i32 = 0;
210
211                // eval work size
212                let mut info = 0;
213                let mut work_size = [Self::Elem::zero()];
214                let mut iwork_size = [0];
215                unsafe {
216                    $lsd(
217                        &m,
218                        &n,
219                        &nrhs,
220                        std::ptr::null_mut(),
221                        &m,
222                        std::ptr::null_mut(),
223                        &m_,
224                        AsPtr::as_mut_ptr(&mut singular_values),
225                        &rcond,
226                        &mut rank,
227                        AsPtr::as_mut_ptr(&mut work_size),
228                        &(-1),
229                        iwork_size.as_mut_ptr(),
230                        &mut info,
231                    )
232                };
233                info.as_lapack_result()?;
234
235                let lwork = work_size[0].to_usize().unwrap();
236                let liwork = iwork_size[0].to_usize().unwrap();
237
238                let work = vec_uninit(lwork);
239                let iwork = vec_uninit(liwork);
240
241                Ok(LeastSquaresWork {
242                    a_layout,
243                    b_layout,
244                    work,
245                    iwork,
246                    rwork: None,
247                    singular_values,
248                })
249            }
250
251            fn calc(
252                &mut self,
253                a: &mut [Self::Elem],
254                b: &mut [Self::Elem],
255            ) -> Result<LeastSquaresRef<Self::Elem>> {
256                let (m, n) = self.a_layout.size();
257                let (m_, nrhs) = self.b_layout.size();
258                assert!(m_ >= m);
259
260                let lwork = self.work.len().to_i32().unwrap();
261
262                // Transpose if a is C-continuous
263                let mut a_t = None;
264                let _ = match self.a_layout {
265                    MatrixLayout::C { .. } => {
266                        let (layout, t) = transpose(self.a_layout, a);
267                        a_t = Some(t);
268                        layout
269                    }
270                    MatrixLayout::F { .. } => self.a_layout,
271                };
272
273                // Transpose if b is C-continuous
274                let mut b_t = None;
275                let b_layout = match self.b_layout {
276                    MatrixLayout::C { .. } => {
277                        let (layout, t) = transpose(self.b_layout, b);
278                        b_t = Some(t);
279                        layout
280                    }
281                    MatrixLayout::F { .. } => self.b_layout,
282                };
283
284                let rcond: <Self::Elem as Scalar>::Real = -1.;
285                let mut rank: i32 = 0;
286
287                let mut info = 0;
288                unsafe {
289                    $lsd(
290                        &m,
291                        &n,
292                        &nrhs,
293                        AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)),
294                        &m,
295                        AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
296                        &m_,
297                        AsPtr::as_mut_ptr(&mut self.singular_values),
298                        &rcond,
299                        &mut rank,
300                        AsPtr::as_mut_ptr(&mut self.work),
301                        &lwork,
302                        AsPtr::as_mut_ptr(&mut self.iwork),
303                        &mut info,
304                    );
305                }
306                info.as_lapack_result()?;
307
308                let singular_values = unsafe { self.singular_values.slice_assume_init_ref() };
309
310                // Skip a_t -> a transpose because A has been destroyed
311                // Re-transpose b
312                if let Some(b_t) = b_t {
313                    transpose_over(b_layout, &b_t, b);
314                }
315
316                Ok(LeastSquaresRef {
317                    singular_values,
318                    rank,
319                })
320            }
321
322            fn eval(
323                mut self,
324                a: &mut [Self::Elem],
325                b: &mut [Self::Elem],
326            ) -> Result<LeastSquaresOwned<Self::Elem>> {
327                let LeastSquaresRef { rank, .. } = self.calc(a, b)?;
328                let singular_values = unsafe { self.singular_values.assume_init() };
329                Ok(LeastSquaresOwned {
330                    singular_values,
331                    rank,
332                })
333            }
334        }
335    };
336}
337impl_least_squares_work_r!(f64, lapack_sys::dgelsd_);
338impl_least_squares_work_r!(f32, lapack_sys::sgelsd_);