ndarray_linalg/least_squares.rs
1//! # Least Squares
2//!
3//! Compute a least-squares solution to the equation Ax = b.
4//! Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.
5//!
6//! Finding the least squares solutions is implemented as traits, meaning
7//! that to solve `A x = b` for a matrix `A` and a RHS `b`, we call
8//! `let result = A.least_squares(&b);`. This returns a `result` of
9//! type `LeastSquaresResult`, the solution for the least square problem
10//! is in `result.solution`.
11//!
//! There are three traits: `LeastSquaresSvd` with the method `least_squares`,
//! which operates on immutable references; `LeastSquaresSvdInto` with the method
//! `least_squares_into`, which takes ownership of both the array `A` and the
//! RHS `b`; and `LeastSquaresSvdInPlace` with the method `least_squares_in_place`,
//! which operates on mutable references for `A` and `b` and destroys these when
//! solving the least squares problem. `LeastSquaresSvdInto` and
//! `LeastSquaresSvdInPlace` avoid an extra allocation for `A` and `b` which
//! `LeastSquaresSvd` has to perform to preserve the values in `A` and `b`.
20//!
21//! All methods use the Lapacke family of methods `*gelsd` which solves the least
22//! squares problem using the SVD with a divide-and-conquer strategy.
23//!
24//! The traits are implemented for value types `f32`, `f64`, `c32` and `c64`
25//! and vector or matrix right-hand-sides (`ArrayBase<S, Ix1>` or `ArrayBase<S, Ix2>`).
26//!
27//! ## Example
28//! ```rust
29//! use approx::AbsDiffEq; // for abs_diff_eq
30//! use ndarray::{array, Array1, Array2};
31//! use ndarray_linalg::{LeastSquaresSvd, LeastSquaresSvdInto, LeastSquaresSvdInPlace};
32//!
33//! let a: Array2<f64> = array![
34//! [1., 1., 1.],
35//! [2., 3., 4.],
36//! [3., 5., 2.],
37//! [4., 2., 5.],
38//! [5., 4., 3.]
39//! ];
40//! // solving for a single right-hand side
41//! let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
42//! let expected: Array1<f64> = array![2., 1., 1.];
43//! let result = a.least_squares(&b).unwrap();
44//! assert!(result.solution.abs_diff_eq(&expected, 1e-12));
45//!
46//! // solving for two right-hand sides at once
47//! let b_2: Array2<f64> =
48//! array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
49//! let expected_2: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
50//! let result_2 = a.least_squares(&b_2).unwrap();
51//! assert!(result_2.solution.abs_diff_eq(&expected_2, 1e-12));
52//!
53//! // using `least_squares_in_place` which overwrites its arguments
54//! let mut a_3 = a.clone();
55//! let mut b_3 = b.clone();
56//! let result_3 = a_3.least_squares_in_place(&mut b_3).unwrap();
57//!
58//! // using `least_squares_into` which consumes its arguments
59//! let result_4 = a.least_squares_into(b).unwrap();
60//! // `a` and `b` have been moved, no longer valid
61//! ```
62
63use lax::*;
64use ndarray::*;
65
66use crate::error::*;
67use crate::layout::*;
68use crate::types::*;
69
70/// Result of a LeastSquares computation
71///
72/// Takes two type parameters, `E`, the element type of the matrix
73/// (one of `f32`, `f64`, `c32` or `c64`) and `I`, the dimension of
74/// b in the equation `Ax = b` (one of `Ix1` or `Ix2`). If `I` is `Ix1`,
75/// the right-hand-side (RHS) is a `n x 1` column vector and the solution
76/// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix
77/// (which can be seen as solving `Ax = b` k times for different b) and
78/// the solution is a `m x k` matrix.
79#[derive(Debug, Clone)]
80pub struct LeastSquaresResult<E: Scalar, I: Dimension> {
81 /// The singular values of the matrix A in `Ax = b`
82 pub singular_values: Array1<E::Real>,
83 /// The solution vector or matrix `x` which is the best
84 /// solution to `Ax = b`, i.e. minimizing the 2-norm `||b - Ax||`
85 pub solution: Array<E, I>,
86 /// The rank of the matrix A in `Ax = b`
87 pub rank: i32,
88 /// If n < m and rank(A) == n, the sum of squares
89 /// If b is a (m x 1) vector, this is a 0-dimensional array (single value)
90 /// If b is a (m x k) matrix, this is a (k x 1) column vector
91 pub residual_sum_of_squares: Option<Array<E::Real, I::Smaller>>,
92}
93/// Solve least squares for immutable references
94pub trait LeastSquaresSvd<E, I>
95where
96 E: Scalar + Lapack,
97 I: Dimension,
98{
99 /// Solve a least squares problem of the form `Ax = rhs`
100 /// by calling `A.least_squares(&rhs)`. `A` and `rhs`
101 /// are unchanged.
102 ///
103 /// `A` and `rhs` must have the same layout, i.e. they must
104 /// be both either row- or column-major format, otherwise a
105 /// `IncompatibleShape` error is raised.
106 fn least_squares(&self, rhs: &ArrayRef<E, I>) -> Result<LeastSquaresResult<E, I>>;
107}
108
109/// Solve least squares for owned matrices
110pub trait LeastSquaresSvdInto<D, E, I>
111where
112 D: Data<Elem = E>,
113 E: Scalar + Lapack,
114 I: Dimension,
115{
116 /// Solve a least squares problem of the form `Ax = rhs`
117 /// by calling `A.least_squares(rhs)`, consuming both `A`
118 /// and `rhs`. This uses the memory location of `A` and
119 /// `rhs`, which avoids some extra memory allocations.
120 ///
121 /// `A` and `rhs` must have the same layout, i.e. they must
122 /// be both either row- or column-major format, otherwise a
123 /// `IncompatibleShape` error is raised.
124 fn least_squares_into(self, rhs: ArrayBase<D, I>) -> Result<LeastSquaresResult<E, I>>;
125}
126
127/// Solve least squares for mutable references, overwriting
128/// the input fields in the process
129pub trait LeastSquaresSvdInPlace<E, I>
130where
131 E: Scalar + Lapack,
132 I: Dimension,
133{
134 /// Solve a least squares problem of the form `Ax = rhs`
135 /// by calling `A.least_squares(&mut rhs)`, overwriting both `A`
136 /// and `rhs`. This uses the memory location of `A` and
137 /// `rhs`, which avoids some extra memory allocations.
138 ///
139 /// `A` and `rhs` must have the same layout, i.e. they must
140 /// be both either row- or column-major format, otherwise a
141 /// `IncompatibleShape` error is raised.
142 fn least_squares_in_place(
143 &mut self,
144 rhs: &mut ArrayRef<E, I>,
145 ) -> Result<LeastSquaresResult<E, I>>;
146}
147
148/// Solve least squares for immutable references and a single
149/// column vector as a right-hand side.
150/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
151/// valid representation for `ArrayBase` (over `E`).
152impl<E> LeastSquaresSvd<E, Ix1> for ArrayRef<E, Ix2>
153where
154 E: Scalar + Lapack,
155{
156 /// Solve a least squares problem of the form `Ax = rhs`
157 /// by calling `A.least_squares(&rhs)`, where `rhs` is a
158 /// single column vector. `A` and `rhs` are unchanged.
159 ///
160 /// `A` and `rhs` must have the same layout, i.e. they must
161 /// be both either row- or column-major format, otherwise a
162 /// `IncompatibleShape` error is raised.
163 fn least_squares(&self, rhs: &ArrayRef<E, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
164 let a = self.to_owned();
165 let b = rhs.to_owned();
166 a.least_squares_into(b)
167 }
168}
169
170/// Solve least squares for immutable references and matrix
171/// (=mulitipe vectors) as a right-hand side.
172/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
173/// valid representation for `ArrayBase` (over `E`).
174impl<E> LeastSquaresSvd<E, Ix2> for ArrayRef<E, Ix2>
175where
176 E: Scalar + Lapack,
177{
178 /// Solve a least squares problem of the form `Ax = rhs`
179 /// by calling `A.least_squares(&rhs)`, where `rhs` is
180 /// matrix. `A` and `rhs` are unchanged.
181 ///
182 /// `A` and `rhs` must have the same layout, i.e. they must
183 /// be both either row- or column-major format, otherwise a
184 /// `IncompatibleShape` error is raised.
185 fn least_squares(&self, rhs: &ArrayRef<E, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
186 let a = self.to_owned();
187 let b = rhs.to_owned();
188 a.least_squares_into(b)
189 }
190}
191
192/// Solve least squares for owned values and a single
193/// column vector as a right-hand side. The matrix and the RHS
194/// vector are consumed.
195///
196/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
197/// valid representation for `ArrayBase`.
198impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix1> for ArrayBase<D1, Ix2>
199where
200 E: Scalar + Lapack,
201 D1: DataMut<Elem = E>,
202 D2: DataMut<Elem = E>,
203{
204 /// Solve a least squares problem of the form `Ax = rhs`
205 /// by calling `A.least_squares(rhs)`, where `rhs` is a
206 /// single column vector. `A` and `rhs` are consumed.
207 ///
208 /// `A` and `rhs` must have the same layout, i.e. they must
209 /// be both either row- or column-major format, otherwise a
210 /// `IncompatibleShape` error is raised.
211 fn least_squares_into(
212 mut self,
213 mut rhs: ArrayBase<D2, Ix1>,
214 ) -> Result<LeastSquaresResult<E, Ix1>> {
215 self.least_squares_in_place(&mut rhs)
216 }
217}
218
219/// Solve least squares for owned values and a matrix
220/// as a right-hand side. The matrix and the RHS matrix
221/// are consumed.
222///
223/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
224/// valid representation for `ArrayBase` (over `E`).
225impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix2> for ArrayBase<D1, Ix2>
226where
227 E: Scalar + Lapack,
228 D1: DataMut<Elem = E>,
229 D2: DataMut<Elem = E>,
230{
231 /// Solve a least squares problem of the form `Ax = rhs`
232 /// by calling `A.least_squares(rhs)`, where `rhs` is a
233 /// matrix. `A` and `rhs` are consumed.
234 ///
235 /// `A` and `rhs` must have the same layout, i.e. they must
236 /// be both either row- or column-major format, otherwise a
237 /// `IncompatibleShape` error is raised.
238 fn least_squares_into(
239 mut self,
240 mut rhs: ArrayBase<D2, Ix2>,
241 ) -> Result<LeastSquaresResult<E, Ix2>> {
242 self.least_squares_in_place(&mut rhs)
243 }
244}
245
246/// Solve least squares for mutable references and a vector
247/// as a right-hand side. Both values are overwritten in the
248/// call.
249///
250/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
251/// valid representation for `ArrayBase` (over `E`).
252impl<E> LeastSquaresSvdInPlace<E, Ix1> for ArrayRef<E, Ix2>
253where
254 E: Scalar + Lapack,
255{
256 /// Solve a least squares problem of the form `Ax = rhs`
257 /// by calling `A.least_squares(rhs)`, where `rhs` is a
258 /// vector. `A` and `rhs` are overwritten in the call.
259 ///
260 /// `A` and `rhs` must have the same layout, i.e. they must
261 /// be both either row- or column-major format, otherwise a
262 /// `IncompatibleShape` error is raised.
263 fn least_squares_in_place(
264 &mut self,
265 rhs: &mut ArrayRef<E, Ix1>,
266 ) -> Result<LeastSquaresResult<E, Ix1>> {
267 if self.shape()[0] != rhs.shape()[0] {
268 return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
269 }
270 let (m, n) = (self.shape()[0], self.shape()[1]);
271 if n > m {
272 // we need a new rhs b/c it will be overwritten with the solution
273 // for which we need `n` entries
274 let mut new_rhs = Array1::<E>::zeros((n,));
275 new_rhs.slice_mut(s![0..m]).assign(rhs);
276 compute_least_squares_srhs(self, &mut new_rhs)
277 } else {
278 compute_least_squares_srhs(self, rhs)
279 }
280 }
281}
282
283fn compute_least_squares_srhs<E>(
284 a: &mut ArrayRef<E, Ix2>,
285 rhs: &mut ArrayRef<E, Ix1>,
286) -> Result<LeastSquaresResult<E, Ix1>>
287where
288 E: Scalar + Lapack,
289{
290 let LeastSquaresOwned::<E> {
291 singular_values,
292 rank,
293 } = E::least_squares(
294 a.layout()?,
295 a.as_allocated_mut()?,
296 rhs.as_slice_memory_order_mut()
297 .ok_or(LinalgError::MemoryNotCont)?,
298 )?;
299
300 let (m, n) = (a.shape()[0], a.shape()[1]);
301 let solution = rhs.slice(s![0..n]).to_owned();
302 let residual_sum_of_squares = compute_residual_scalar(m, n, rank, rhs);
303 Ok(LeastSquaresResult {
304 solution,
305 singular_values: Array::from_shape_vec((singular_values.len(),), singular_values)?,
306 rank,
307 residual_sum_of_squares,
308 })
309}
310
311fn compute_residual_scalar<E: Scalar>(
312 m: usize,
313 n: usize,
314 rank: i32,
315 b: &ArrayRef<E, Ix1>,
316) -> Option<Array<E::Real, Ix0>> {
317 if m < n || n != rank as usize {
318 return None;
319 }
320 let mut arr: Array<E::Real, Ix0> = Array::zeros(());
321 arr[()] = b.slice(s![n..]).mapv(|x| x.powi(2).abs()).sum();
322 Some(arr)
323}
324
325/// Solve least squares for mutable references and a matrix
326/// as a right-hand side. Both values are overwritten in the
327/// call.
328///
329/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
330/// valid representation for `ArrayBase` (over `E`).
331impl<E> LeastSquaresSvdInPlace<E, Ix2> for ArrayRef<E, Ix2>
332where
333 E: Scalar + Lapack,
334{
335 /// Solve a least squares problem of the form `Ax = rhs`
336 /// by calling `A.least_squares(rhs)`, where `rhs` is a
337 /// matrix. `A` and `rhs` are overwritten in the call.
338 ///
339 /// `A` and `rhs` must have the same layout, i.e. they must
340 /// be both either row- or column-major format, otherwise a
341 /// `IncompatibleShape` error is raised.
342 fn least_squares_in_place(
343 &mut self,
344 rhs: &mut ArrayRef<E, Ix2>,
345 ) -> Result<LeastSquaresResult<E, Ix2>> {
346 if self.shape()[0] != rhs.shape()[0] {
347 return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
348 }
349 let (m, n) = (self.shape()[0], self.shape()[1]);
350 if n > m {
351 // we need a new rhs b/c it will be overwritten with the solution
352 // for which we need `n` entries
353 let k = rhs.shape()[1];
354 let mut new_rhs = match self.layout()? {
355 MatrixLayout::C { .. } => Array2::<E>::zeros((n, k)),
356 MatrixLayout::F { .. } => Array2::<E>::zeros((n, k).f()),
357 };
358 new_rhs.slice_mut(s![0..m, ..]).assign(rhs);
359 compute_least_squares_nrhs(self, &mut new_rhs)
360 } else {
361 compute_least_squares_nrhs(self, rhs)
362 }
363 }
364}
365
366fn compute_least_squares_nrhs<E>(
367 a: &mut ArrayRef<E, Ix2>,
368 rhs: &mut ArrayRef<E, Ix2>,
369) -> Result<LeastSquaresResult<E, Ix2>>
370where
371 E: Scalar + Lapack,
372{
373 let a_layout = a.layout()?;
374 let rhs_layout = rhs.layout()?;
375 let LeastSquaresOwned::<E> {
376 singular_values,
377 rank,
378 } = E::least_squares_nrhs(
379 a_layout,
380 a.as_allocated_mut()?,
381 rhs_layout,
382 rhs.as_allocated_mut()?,
383 )?;
384
385 let solution: Array2<E> = rhs.slice(s![..a.shape()[1], ..]).to_owned();
386 let singular_values = Array::from_shape_vec((singular_values.len(),), singular_values)?;
387 let (m, n) = (a.shape()[0], a.shape()[1]);
388 let residual_sum_of_squares = compute_residual_array1(m, n, rank, rhs);
389 Ok(LeastSquaresResult {
390 solution,
391 singular_values,
392 rank,
393 residual_sum_of_squares,
394 })
395}
396
397fn compute_residual_array1<E: Scalar>(
398 m: usize,
399 n: usize,
400 rank: i32,
401 b: &ArrayRef<E, Ix2>,
402) -> Option<Array1<E::Real>> {
403 if m < n || n != rank as usize {
404 return None;
405 }
406 Some(
407 b.slice(s![n.., ..])
408 .mapv(|x| x.powi(2).abs())
409 .sum_axis(Axis(0)),
410 )
411}
412
#[cfg(test)]
mod tests {
    use crate::{error::LinalgError, *};
    use approx::AbsDiffEq;
    use ndarray::*;

    //
    // The tests below exercise the least squares traits on the different
    // array types. Combinations covered by a test in this module:
    //
    //               | least_squares | ls_into | ls_in_place |
    // --------------+---------------+---------+-------------+
    // Array         |      yes      |   yes   |     yes     |
    // ArcArray      |      yes      |   yes   |      -      |
    // CowArray      |      yes      |   yes   |     yes     |
    // ArrayView     |      yes      |    -    |      -      |
    // ArrayViewMut  |      yes      |    -    |     yes     |
    //
    // A few tests additionally mix storage types between `A` and `b`.
    //

    /// The 3 x 2 matrix used as `A` by every test.
    fn sample_matrix() -> Array2<f64> {
        array![[1., 2.], [4., 5.], [3., 4.]]
    }

    /// The right-hand side matching `sample_matrix`.
    fn sample_rhs() -> Array1<f64> {
        array![1., 2., 3.]
    }

    /// Check rank, residual and solution of `res` against the reference
    /// values for the sample problem, using the untouched inputs `a`, `b`.
    fn assert_result<D1, D2>(
        a: &ArrayBase<D1, Ix2>,
        b: &ArrayBase<D2, Ix1>,
        res: &LeastSquaresResult<f64, Ix1>,
    ) where
        D1: Data<Elem = f64>,
        D2: Data<Elem = f64>,
    {
        assert_eq!(res.rank, 2);

        // Recompute the residual from scratch and compare with the reported one.
        let b_hat = a.dot(&res.solution);
        let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum();
        let reported = res.residual_sum_of_squares.as_ref().unwrap()[()];
        assert!(reported.abs_diff_eq(&rssq, 1e-12));

        let expected = array![-0.428571428571429, 0.85714285714285];
        assert!(res.solution.abs_diff_eq(&expected, 1e-12));
    }

    #[test]
    fn on_arc() {
        let a: ArcArray2<f64> = sample_matrix().into_shared();
        let b: ArcArray1<f64> = sample_rhs().into_shared();
        let res = a.least_squares(&b).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn on_cow() {
        let a = CowArray::from(sample_matrix());
        let b = CowArray::from(sample_rhs());
        let res = a.least_squares(&b).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn on_view() {
        let a: Array2<f64> = sample_matrix();
        let b: Array1<f64> = sample_rhs();
        let av = a.view();
        let bv = b.view();
        let res = av.least_squares(&bv).unwrap();
        assert_result(&av, &bv, &res);
    }

    #[test]
    fn on_view_mut() {
        let mut a: Array2<f64> = sample_matrix();
        let mut b: Array1<f64> = sample_rhs();
        let av = a.view_mut();
        let bv = b.view_mut();
        let res = av.least_squares(&bv).unwrap();
        assert_result(&av, &bv, &res);
    }

    #[test]
    fn on_cow_view() {
        let a = CowArray::from(sample_matrix());
        let b: Array1<f64> = sample_rhs();
        let bv = b.view();
        let res = a.least_squares(&bv).unwrap();
        assert_result(&a, &bv, &res);
    }

    #[test]
    fn into_on_owned() {
        let a: Array2<f64> = sample_matrix();
        let b: Array1<f64> = sample_rhs();
        // Solve on copies so the originals stay available for the check.
        let ac = a.clone();
        let bc = b.clone();
        let res = ac.least_squares_into(bc).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn into_on_arc() {
        let a: ArcArray2<f64> = sample_matrix().into_shared();
        let b: ArcArray1<f64> = sample_rhs().into_shared();
        let a2 = a.clone();
        let b2 = b.clone();
        let res = a2.least_squares_into(b2).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn into_on_cow() {
        let a = CowArray::from(sample_matrix());
        let b = CowArray::from(sample_rhs());
        let a2 = a.clone();
        let b2 = b.clone();
        let res = a2.least_squares_into(b2).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn into_on_owned_cow() {
        let a: Array2<f64> = sample_matrix();
        let b = CowArray::from(sample_rhs());
        let ac = a.clone();
        let b2 = b.clone();
        let res = ac.least_squares_into(b2).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn in_place_on_owned() {
        let a = sample_matrix();
        let b = sample_rhs();
        let mut a2 = a.clone();
        let mut b2 = b.clone();
        let res = a2.least_squares_in_place(&mut b2).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn in_place_on_cow() {
        let a = CowArray::from(sample_matrix());
        let b = CowArray::from(sample_rhs());
        let mut a2 = a.clone();
        let mut b2 = b.clone();
        let res = a2.least_squares_in_place(&mut b2).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn in_place_on_mut_view() {
        let a = sample_matrix();
        let b = sample_rhs();
        let mut a2 = a.clone();
        let mut b2 = b.clone();
        let av = &mut a2.view_mut();
        let bv = &mut b2.view_mut();
        let res = av.least_squares_in_place(bv).unwrap();
        assert_result(&a, &b, &res);
    }

    #[test]
    fn in_place_on_owned_cow() {
        let a = sample_matrix();
        let b = CowArray::from(sample_rhs());
        let mut a2 = a.clone();
        let mut b2 = b.clone();
        let res = a2.least_squares_in_place(&mut b2).unwrap();
        assert_result(&a, &b, &res);
    }

    //
    // Error cases
    //

    #[test]
    fn incompatible_shape_error_on_mismatching_num_rows() {
        // `A` has three rows but `b` only two entries.
        let a: Array2<f64> = sample_matrix();
        let b: Array1<f64> = array![1., 2.];
        match a.least_squares(&b) {
            Err(LinalgError::Shape(e)) if e.kind() == ErrorKind::IncompatibleShape => {}
            _ => panic!("Should be raise IncompatibleShape"),
        }
    }
}