ndarray_linalg/
least_squares.rs

1//! # Least Squares
2//!
3//! Compute a least-squares solution to the equation Ax = b.
4//! Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.
5//!
6//! Finding the least squares solutions is implemented as traits, meaning
7//! that to solve `A x = b` for a matrix `A` and a RHS `b`, we call
8//! `let result = A.least_squares(&b);`. This returns a `result` of
9//! type `LeastSquaresResult`, the solution for the least square problem
10//! is in `result.solution`.
11//!
//! There are three traits: `LeastSquaresSvd` with the method `least_squares`,
//! which operates on immutable references; `LeastSquaresSvdInto` with the method
//! `least_squares_into`, which takes ownership of both the array `A` and the
//! RHS `b`; and `LeastSquaresSvdInPlace` with the method `least_squares_in_place`,
//! which operates on mutable references for `A` and `b` and destroys these when
//! solving the least squares problem. `LeastSquaresSvdInto` and
//! `LeastSquaresSvdInPlace` avoid an extra allocation for `A` and `b` which
//! `LeastSquaresSvd` has to perform to preserve the values in `A` and `b`.
20//!
21//! All methods use the Lapacke family of methods `*gelsd` which solves the least
22//! squares problem using the SVD with a divide-and-conquer strategy.
23//!
24//! The traits are implemented for value types `f32`, `f64`, `c32` and `c64`
25//! and vector or matrix right-hand-sides (`ArrayBase<S, Ix1>` or `ArrayBase<S, Ix2>`).
26//!
27//! ## Example
28//! ```rust
29//! use approx::AbsDiffEq; // for abs_diff_eq
30//! use ndarray::{array, Array1, Array2};
31//! use ndarray_linalg::{LeastSquaresSvd, LeastSquaresSvdInto, LeastSquaresSvdInPlace};
32//!
33//! let a: Array2<f64> = array![
34//!     [1., 1., 1.],
35//!     [2., 3., 4.],
36//!     [3., 5., 2.],
37//!     [4., 2., 5.],
38//!     [5., 4., 3.]
39//! ];
40//! // solving for a single right-hand side
41//! let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
42//! let expected: Array1<f64> = array![2., 1., 1.];
43//! let result = a.least_squares(&b).unwrap();
44//! assert!(result.solution.abs_diff_eq(&expected, 1e-12));
45//!
46//! // solving for two right-hand sides at once
47//! let b_2: Array2<f64> =
48//!     array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
49//! let expected_2: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
50//! let result_2 = a.least_squares(&b_2).unwrap();
51//! assert!(result_2.solution.abs_diff_eq(&expected_2, 1e-12));
52//!
53//! // using `least_squares_in_place` which overwrites its arguments
54//! let mut a_3 = a.clone();
55//! let mut b_3 = b.clone();
56//! let result_3 = a_3.least_squares_in_place(&mut b_3).unwrap();
57//!
58//! // using `least_squares_into` which consumes its arguments
59//! let result_4 = a.least_squares_into(b).unwrap();
60//! // `a` and `b` have been moved, no longer valid
61//! ```
62
63use lax::*;
64use ndarray::*;
65
66use crate::error::*;
67use crate::layout::*;
68use crate::types::*;
69
70/// Result of a LeastSquares computation
71///
72/// Takes two type parameters, `E`, the element type of the matrix
73/// (one of `f32`, `f64`, `c32` or `c64`) and `I`, the dimension of
74/// b in the equation `Ax = b` (one of `Ix1` or `Ix2`). If `I` is `Ix1`,
75/// the  right-hand-side (RHS) is a `n x 1` column vector and the solution
76/// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix
77/// (which can be seen as solving `Ax = b` k times for different b) and
78/// the solution is a `m x k` matrix.
79#[derive(Debug, Clone)]
80pub struct LeastSquaresResult<E: Scalar, I: Dimension> {
81    /// The singular values of the matrix A in `Ax = b`
82    pub singular_values: Array1<E::Real>,
83    /// The solution vector or matrix `x` which is the best
84    /// solution to `Ax = b`, i.e. minimizing the 2-norm `||b - Ax||`
85    pub solution: Array<E, I>,
86    /// The rank of the matrix A in `Ax = b`
87    pub rank: i32,
88    /// If n < m and rank(A) == n, the sum of squares
89    /// If b is a (m x 1) vector, this is a 0-dimensional array (single value)
90    /// If b is a (m x k) matrix, this is a (k x 1) column vector
91    pub residual_sum_of_squares: Option<Array<E::Real, I::Smaller>>,
92}
93/// Solve least squares for immutable references
94pub trait LeastSquaresSvd<E, I>
95where
96    E: Scalar + Lapack,
97    I: Dimension,
98{
99    /// Solve a least squares problem of the form `Ax = rhs`
100    /// by calling `A.least_squares(&rhs)`. `A` and `rhs`
101    /// are unchanged.
102    ///
103    /// `A` and `rhs` must have the same layout, i.e. they must
104    /// be both either row- or column-major format, otherwise a
105    /// `IncompatibleShape` error is raised.
106    fn least_squares(&self, rhs: &ArrayRef<E, I>) -> Result<LeastSquaresResult<E, I>>;
107}
108
109/// Solve least squares for owned matrices
110pub trait LeastSquaresSvdInto<D, E, I>
111where
112    D: Data<Elem = E>,
113    E: Scalar + Lapack,
114    I: Dimension,
115{
116    /// Solve a least squares problem of the form `Ax = rhs`
117    /// by calling `A.least_squares(rhs)`, consuming both `A`
118    /// and `rhs`. This uses the memory location of `A` and
119    /// `rhs`, which avoids some extra memory allocations.
120    ///
121    /// `A` and `rhs` must have the same layout, i.e. they must
122    /// be both either row- or column-major format, otherwise a
123    /// `IncompatibleShape` error is raised.
124    fn least_squares_into(self, rhs: ArrayBase<D, I>) -> Result<LeastSquaresResult<E, I>>;
125}
126
127/// Solve least squares for mutable references, overwriting
128/// the input fields in the process
129pub trait LeastSquaresSvdInPlace<E, I>
130where
131    E: Scalar + Lapack,
132    I: Dimension,
133{
134    /// Solve a least squares problem of the form `Ax = rhs`
135    /// by calling `A.least_squares(&mut rhs)`, overwriting both `A`
136    /// and `rhs`. This uses the memory location of `A` and
137    /// `rhs`, which avoids some extra memory allocations.
138    ///
139    /// `A` and `rhs` must have the same layout, i.e. they must
140    /// be both either row- or column-major format, otherwise a
141    /// `IncompatibleShape` error is raised.
142    fn least_squares_in_place(
143        &mut self,
144        rhs: &mut ArrayRef<E, I>,
145    ) -> Result<LeastSquaresResult<E, I>>;
146}
147
148/// Solve least squares for immutable references and a single
149/// column vector as a right-hand side.
150/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
151/// valid representation for `ArrayBase` (over `E`).
152impl<E> LeastSquaresSvd<E, Ix1> for ArrayRef<E, Ix2>
153where
154    E: Scalar + Lapack,
155{
156    /// Solve a least squares problem of the form `Ax = rhs`
157    /// by calling `A.least_squares(&rhs)`, where `rhs` is a
158    /// single column vector. `A` and `rhs` are unchanged.
159    ///
160    /// `A` and `rhs` must have the same layout, i.e. they must
161    /// be both either row- or column-major format, otherwise a
162    /// `IncompatibleShape` error is raised.
163    fn least_squares(&self, rhs: &ArrayRef<E, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
164        let a = self.to_owned();
165        let b = rhs.to_owned();
166        a.least_squares_into(b)
167    }
168}
169
170/// Solve least squares for immutable references and matrix
171/// (=mulitipe vectors) as a right-hand side.
172/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
173/// valid representation for `ArrayBase` (over `E`).
174impl<E> LeastSquaresSvd<E, Ix2> for ArrayRef<E, Ix2>
175where
176    E: Scalar + Lapack,
177{
178    /// Solve a least squares problem of the form `Ax = rhs`
179    /// by calling `A.least_squares(&rhs)`, where `rhs` is
180    /// matrix. `A` and `rhs` are unchanged.
181    ///
182    /// `A` and `rhs` must have the same layout, i.e. they must
183    /// be both either row- or column-major format, otherwise a
184    /// `IncompatibleShape` error is raised.
185    fn least_squares(&self, rhs: &ArrayRef<E, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
186        let a = self.to_owned();
187        let b = rhs.to_owned();
188        a.least_squares_into(b)
189    }
190}
191
192/// Solve least squares for owned values and a single
193/// column vector as a right-hand side. The matrix and the RHS
194/// vector are consumed.
195///
196/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
197/// valid representation for `ArrayBase`.
198impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix1> for ArrayBase<D1, Ix2>
199where
200    E: Scalar + Lapack,
201    D1: DataMut<Elem = E>,
202    D2: DataMut<Elem = E>,
203{
204    /// Solve a least squares problem of the form `Ax = rhs`
205    /// by calling `A.least_squares(rhs)`, where `rhs` is a
206    /// single column vector. `A` and `rhs` are consumed.
207    ///
208    /// `A` and `rhs` must have the same layout, i.e. they must
209    /// be both either row- or column-major format, otherwise a
210    /// `IncompatibleShape` error is raised.
211    fn least_squares_into(
212        mut self,
213        mut rhs: ArrayBase<D2, Ix1>,
214    ) -> Result<LeastSquaresResult<E, Ix1>> {
215        self.least_squares_in_place(&mut rhs)
216    }
217}
218
219/// Solve least squares for owned values and a matrix
220/// as a right-hand side. The matrix and the RHS matrix
221/// are consumed.
222///
223/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
224/// valid representation for `ArrayBase` (over `E`).
225impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix2> for ArrayBase<D1, Ix2>
226where
227    E: Scalar + Lapack,
228    D1: DataMut<Elem = E>,
229    D2: DataMut<Elem = E>,
230{
231    /// Solve a least squares problem of the form `Ax = rhs`
232    /// by calling `A.least_squares(rhs)`, where `rhs` is a
233    /// matrix. `A` and `rhs` are consumed.
234    ///
235    /// `A` and `rhs` must have the same layout, i.e. they must
236    /// be both either row- or column-major format, otherwise a
237    /// `IncompatibleShape` error is raised.
238    fn least_squares_into(
239        mut self,
240        mut rhs: ArrayBase<D2, Ix2>,
241    ) -> Result<LeastSquaresResult<E, Ix2>> {
242        self.least_squares_in_place(&mut rhs)
243    }
244}
245
246/// Solve least squares for mutable references and a vector
247/// as a right-hand side. Both values are overwritten in the
248/// call.
249///
250/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
251/// valid representation for `ArrayBase` (over `E`).
252impl<E> LeastSquaresSvdInPlace<E, Ix1> for ArrayRef<E, Ix2>
253where
254    E: Scalar + Lapack,
255{
256    /// Solve a least squares problem of the form `Ax = rhs`
257    /// by calling `A.least_squares(rhs)`, where `rhs` is a
258    /// vector. `A` and `rhs` are overwritten in the call.
259    ///
260    /// `A` and `rhs` must have the same layout, i.e. they must
261    /// be both either row- or column-major format, otherwise a
262    /// `IncompatibleShape` error is raised.
263    fn least_squares_in_place(
264        &mut self,
265        rhs: &mut ArrayRef<E, Ix1>,
266    ) -> Result<LeastSquaresResult<E, Ix1>> {
267        if self.shape()[0] != rhs.shape()[0] {
268            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
269        }
270        let (m, n) = (self.shape()[0], self.shape()[1]);
271        if n > m {
272            // we need a new rhs b/c it will be overwritten with the solution
273            // for which we need `n` entries
274            let mut new_rhs = Array1::<E>::zeros((n,));
275            new_rhs.slice_mut(s![0..m]).assign(rhs);
276            compute_least_squares_srhs(self, &mut new_rhs)
277        } else {
278            compute_least_squares_srhs(self, rhs)
279        }
280    }
281}
282
283fn compute_least_squares_srhs<E>(
284    a: &mut ArrayRef<E, Ix2>,
285    rhs: &mut ArrayRef<E, Ix1>,
286) -> Result<LeastSquaresResult<E, Ix1>>
287where
288    E: Scalar + Lapack,
289{
290    let LeastSquaresOwned::<E> {
291        singular_values,
292        rank,
293    } = E::least_squares(
294        a.layout()?,
295        a.as_allocated_mut()?,
296        rhs.as_slice_memory_order_mut()
297            .ok_or(LinalgError::MemoryNotCont)?,
298    )?;
299
300    let (m, n) = (a.shape()[0], a.shape()[1]);
301    let solution = rhs.slice(s![0..n]).to_owned();
302    let residual_sum_of_squares = compute_residual_scalar(m, n, rank, rhs);
303    Ok(LeastSquaresResult {
304        solution,
305        singular_values: Array::from_shape_vec((singular_values.len(),), singular_values)?,
306        rank,
307        residual_sum_of_squares,
308    })
309}
310
311fn compute_residual_scalar<E: Scalar>(
312    m: usize,
313    n: usize,
314    rank: i32,
315    b: &ArrayRef<E, Ix1>,
316) -> Option<Array<E::Real, Ix0>> {
317    if m < n || n != rank as usize {
318        return None;
319    }
320    let mut arr: Array<E::Real, Ix0> = Array::zeros(());
321    arr[()] = b.slice(s![n..]).mapv(|x| x.powi(2).abs()).sum();
322    Some(arr)
323}
324
325/// Solve least squares for mutable references and a matrix
326/// as a right-hand side. Both values are overwritten in the
327/// call.
328///
329/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
330/// valid representation for `ArrayBase` (over `E`).
331impl<E> LeastSquaresSvdInPlace<E, Ix2> for ArrayRef<E, Ix2>
332where
333    E: Scalar + Lapack,
334{
335    /// Solve a least squares problem of the form `Ax = rhs`
336    /// by calling `A.least_squares(rhs)`, where `rhs` is a
337    /// matrix. `A` and `rhs` are overwritten in the call.
338    ///
339    /// `A` and `rhs` must have the same layout, i.e. they must
340    /// be both either row- or column-major format, otherwise a
341    /// `IncompatibleShape` error is raised.
342    fn least_squares_in_place(
343        &mut self,
344        rhs: &mut ArrayRef<E, Ix2>,
345    ) -> Result<LeastSquaresResult<E, Ix2>> {
346        if self.shape()[0] != rhs.shape()[0] {
347            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
348        }
349        let (m, n) = (self.shape()[0], self.shape()[1]);
350        if n > m {
351            // we need a new rhs b/c it will be overwritten with the solution
352            // for which we need `n` entries
353            let k = rhs.shape()[1];
354            let mut new_rhs = match self.layout()? {
355                MatrixLayout::C { .. } => Array2::<E>::zeros((n, k)),
356                MatrixLayout::F { .. } => Array2::<E>::zeros((n, k).f()),
357            };
358            new_rhs.slice_mut(s![0..m, ..]).assign(rhs);
359            compute_least_squares_nrhs(self, &mut new_rhs)
360        } else {
361            compute_least_squares_nrhs(self, rhs)
362        }
363    }
364}
365
366fn compute_least_squares_nrhs<E>(
367    a: &mut ArrayRef<E, Ix2>,
368    rhs: &mut ArrayRef<E, Ix2>,
369) -> Result<LeastSquaresResult<E, Ix2>>
370where
371    E: Scalar + Lapack,
372{
373    let a_layout = a.layout()?;
374    let rhs_layout = rhs.layout()?;
375    let LeastSquaresOwned::<E> {
376        singular_values,
377        rank,
378    } = E::least_squares_nrhs(
379        a_layout,
380        a.as_allocated_mut()?,
381        rhs_layout,
382        rhs.as_allocated_mut()?,
383    )?;
384
385    let solution: Array2<E> = rhs.slice(s![..a.shape()[1], ..]).to_owned();
386    let singular_values = Array::from_shape_vec((singular_values.len(),), singular_values)?;
387    let (m, n) = (a.shape()[0], a.shape()[1]);
388    let residual_sum_of_squares = compute_residual_array1(m, n, rank, rhs);
389    Ok(LeastSquaresResult {
390        solution,
391        singular_values,
392        rank,
393        residual_sum_of_squares,
394    })
395}
396
397fn compute_residual_array1<E: Scalar>(
398    m: usize,
399    n: usize,
400    rank: i32,
401    b: &ArrayRef<E, Ix2>,
402) -> Option<Array1<E::Real>> {
403    if m < n || n != rank as usize {
404        return None;
405    }
406    Some(
407        b.slice(s![n.., ..])
408            .mapv(|x| x.powi(2).abs())
409            .sum_axis(Axis(0)),
410    )
411}
412
#[cfg(test)]
mod tests {
    use crate::{error::LinalgError, *};
    use approx::AbsDiffEq;
    use ndarray::*;

    //
    // Exercise the least squares traits on the different array types.
    // Combinations covered below (matrix type, then RHS type):
    //
    // least_squares          : ArcArray/ArcArray, CowArray/CowArray,
    //                          ArrayView/ArrayView, ArrayViewMut/ArrayViewMut,
    //                          CowArray/ArrayView
    // least_squares_into     : Array/Array, ArcArray/ArcArray,
    //                          CowArray/CowArray, Array/CowArray
    // least_squares_in_place : Array/Array, CowArray/CowArray,
    //                          ArrayViewMut/ArrayViewMut, Array/CowArray
    //

    /// The 3 x 2 system matrix shared by the tests.
    fn sample_matrix() -> Array2<f64> {
        array![[1., 2.], [4., 5.], [3., 4.]]
    }

    /// The right-hand side shared by the tests.
    fn sample_rhs() -> Array1<f64> {
        array![1., 2., 3.]
    }

    /// Checks `res` against the known answer for the shared system: rank 2,
    /// a residual matching `|b - a * solution|^2`, and the expected solution.
    fn assert_result<D1, D2>(
        a: &ArrayBase<D1, Ix2>,
        b: &ArrayBase<D2, Ix1>,
        res: &LeastSquaresResult<f64, Ix1>,
    ) where
        D1: Data<Elem = f64>,
        D2: Data<Elem = f64>,
    {
        assert_eq!(res.rank, 2);

        let b_hat = a.dot(&res.solution);
        let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum();
        let reported = res.residual_sum_of_squares.as_ref().unwrap()[()];
        assert!(reported.abs_diff_eq(&rssq, 1e-12));

        let expected = array![-0.428571428571429, 0.85714285714285];
        assert!(res.solution.abs_diff_eq(&expected, 1e-12));
    }

    #[test]
    fn on_arc() {
        let a: ArcArray2<f64> = sample_matrix().into_shared();
        let b: ArcArray1<f64> = sample_rhs().into_shared();
        let res = a.least_squares(&b).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn on_cow() {
        let a = CowArray::from(sample_matrix());
        let b = CowArray::from(sample_rhs());
        let res = a.least_squares(&b).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn on_view() {
        let a = sample_matrix();
        let b = sample_rhs();
        let av = a.view();
        let bv = b.view();
        let res = av.least_squares(&bv).unwrap();
        assert_result(&av, &bv, &res);
    }

    #[test]
    fn on_view_mut() {
        let mut a = sample_matrix();
        let mut b = sample_rhs();
        let av = a.view_mut();
        let bv = b.view_mut();
        let res = av.least_squares(&bv).unwrap();
        assert_result(&av, &bv, &res);
    }

    #[test]
    fn on_cow_view() {
        let a = CowArray::from(sample_matrix());
        let b = sample_rhs();
        let bv = b.view();
        let res = a.least_squares(&bv).unwrap();
        assert_result(&a, &bv, &res);
    }

    #[test]
    fn into_on_owned() {
        let a = sample_matrix();
        let b = sample_rhs();
        let res = a.clone().least_squares_into(b.clone()).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn into_on_arc() {
        let a: ArcArray2<f64> = sample_matrix().into_shared();
        let b: ArcArray1<f64> = sample_rhs().into_shared();
        let res = a.clone().least_squares_into(b.clone()).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn into_on_cow() {
        let a = CowArray::from(sample_matrix());
        let b = CowArray::from(sample_rhs());
        let res = a.clone().least_squares_into(b.clone()).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn into_on_owned_cow() {
        let a = sample_matrix();
        let b = CowArray::from(sample_rhs());
        let res = a.clone().least_squares_into(b.clone()).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn in_place_on_owned() {
        let a = sample_matrix();
        let b = sample_rhs();
        let mut a_work = a.clone();
        let mut b_work = b.clone();
        let res = a_work.least_squares_in_place(&mut b_work).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn in_place_on_cow() {
        let a = CowArray::from(sample_matrix());
        let b = CowArray::from(sample_rhs());
        let mut a_work = a.clone();
        let mut b_work = b.clone();
        let res = a_work.least_squares_in_place(&mut b_work).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn in_place_on_mut_view() {
        let a = sample_matrix();
        let b = sample_rhs();
        let mut a_work = a.clone();
        let mut b_work = b.clone();
        let mut av = a_work.view_mut();
        let mut bv = b_work.view_mut();
        let res = av.least_squares_in_place(&mut bv).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn in_place_on_owned_cow() {
        let a = sample_matrix();
        let b = CowArray::from(sample_rhs());
        let mut a_work = a.clone();
        let mut b_work = b.clone();
        let res = a_work.least_squares_in_place(&mut b_work).unwrap();
        assert_result(&a, &b, &res);
    }

    //
    // Error cases
    //
    #[test]
    fn incompatible_shape_error_on_mismatching_num_rows() {
        let a = sample_matrix();
        // Two entries only, while `a` has three rows.
        let b: Array1<f64> = array![1., 2.];
        let outcome = a.least_squares(&b);
        let is_shape_mismatch = matches!(
            &outcome,
            Err(LinalgError::Shape(e)) if e.kind() == ErrorKind::IncompatibleShape
        );
        if !is_shape_mismatch {
            panic!("Should be raise IncompatibleShape");
        }
    }
}