ndarray_linalg/least_squares.rs
1//! # Least Squares
2//!
//! Compute a least-squares solution to the equation `Ax = b`, i.e.
//! a vector `x` such that the 2-norm `||b - Ax||` is minimized.
5//!
//! Solving least squares problems is implemented via traits: to solve
//! `A x = b` for a matrix `A` and a right-hand side (RHS) `b`, call
//! `let result = A.least_squares(&b);`. This returns a `result` of
//! type `LeastSquaresResult`; the solution of the least squares problem
//! is in `result.solution`.
11//!
//! There are three traits: `LeastSquaresSvd` with the method `least_squares`,
//! which operates on immutable references; `LeastSquaresSvdInto` with the method
//! `least_squares_into`, which takes ownership of both the array `A` and the
//! RHS `b`; and `LeastSquaresSvdInPlace` with the method `least_squares_in_place`,
//! which operates on mutable references to `A` and `b` and overwrites them while
//! solving the least squares problem. `LeastSquaresSvdInto` and
//! `LeastSquaresSvdInPlace` avoid the extra allocation for `A` and `b` that
//! `LeastSquaresSvd` has to perform to preserve the values in `A` and `b`.
20//!
//! All methods use the LAPACK(E) `*gelsd` family of routines, which solve the
//! least squares problem using the SVD with a divide-and-conquer strategy.
23//!
24//! The traits are implemented for value types `f32`, `f64`, `c32` and `c64`
25//! and vector or matrix right-hand-sides (`ArrayBase<S, Ix1>` or `ArrayBase<S, Ix2>`).
26//!
27//! ## Example
28//! ```rust
29//! use approx::AbsDiffEq; // for abs_diff_eq
30//! use ndarray::{array, Array1, Array2};
31//! use ndarray_linalg::{LeastSquaresSvd, LeastSquaresSvdInto, LeastSquaresSvdInPlace};
32//!
33//! let a: Array2<f64> = array![
34//! [1., 1., 1.],
35//! [2., 3., 4.],
36//! [3., 5., 2.],
37//! [4., 2., 5.],
38//! [5., 4., 3.]
39//! ];
40//! // solving for a single right-hand side
41//! let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
42//! let expected: Array1<f64> = array![2., 1., 1.];
43//! let result = a.least_squares(&b).unwrap();
44//! assert!(result.solution.abs_diff_eq(&expected, 1e-12));
45//!
46//! // solving for two right-hand sides at once
47//! let b_2: Array2<f64> =
48//! array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
49//! let expected_2: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
50//! let result_2 = a.least_squares(&b_2).unwrap();
51//! assert!(result_2.solution.abs_diff_eq(&expected_2, 1e-12));
52//!
53//! // using `least_squares_in_place` which overwrites its arguments
54//! let mut a_3 = a.clone();
55//! let mut b_3 = b.clone();
56//! let result_3 = a_3.least_squares_in_place(&mut b_3).unwrap();
57//!
58//! // using `least_squares_into` which consumes its arguments
59//! let result_4 = a.least_squares_into(b).unwrap();
60//! // `a` and `b` have been moved, no longer valid
61//! ```
62
63use lax::*;
64use ndarray::*;
65
66use crate::error::*;
67use crate::layout::*;
68use crate::types::*;
69
70/// Result of a LeastSquares computation
71///
72/// Takes two type parameters, `E`, the element type of the matrix
73/// (one of `f32`, `f64`, `c32` or `c64`) and `I`, the dimension of
74/// b in the equation `Ax = b` (one of `Ix1` or `Ix2`). If `I` is `Ix1`,
75/// the right-hand-side (RHS) is a `n x 1` column vector and the solution
76/// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix
77/// (which can be seen as solving `Ax = b` k times for different b) and
78/// the solution is a `m x k` matrix.
79#[derive(Debug, Clone)]
80pub struct LeastSquaresResult<E: Scalar, I: Dimension> {
81 /// The singular values of the matrix A in `Ax = b`
82 pub singular_values: Array1<E::Real>,
83 /// The solution vector or matrix `x` which is the best
84 /// solution to `Ax = b`, i.e. minimizing the 2-norm `||b - Ax||`
85 pub solution: Array<E, I>,
86 /// The rank of the matrix A in `Ax = b`
87 pub rank: i32,
88 /// If n < m and rank(A) == n, the sum of squares
89 /// If b is a (m x 1) vector, this is a 0-dimensional array (single value)
90 /// If b is a (m x k) matrix, this is a (k x 1) column vector
91 pub residual_sum_of_squares: Option<Array<E::Real, I::Smaller>>,
92}
93/// Solve least squares for immutable references
94pub trait LeastSquaresSvd<D, E, I>
95where
96 D: Data<Elem = E>,
97 E: Scalar + Lapack,
98 I: Dimension,
99{
100 /// Solve a least squares problem of the form `Ax = rhs`
101 /// by calling `A.least_squares(&rhs)`. `A` and `rhs`
102 /// are unchanged.
103 ///
104 /// `A` and `rhs` must have the same layout, i.e. they must
105 /// be both either row- or column-major format, otherwise a
106 /// `IncompatibleShape` error is raised.
107 fn least_squares(&self, rhs: &ArrayBase<D, I>) -> Result<LeastSquaresResult<E, I>>;
108}
109
110/// Solve least squares for owned matrices
111pub trait LeastSquaresSvdInto<D, E, I>
112where
113 D: Data<Elem = E>,
114 E: Scalar + Lapack,
115 I: Dimension,
116{
117 /// Solve a least squares problem of the form `Ax = rhs`
118 /// by calling `A.least_squares(rhs)`, consuming both `A`
119 /// and `rhs`. This uses the memory location of `A` and
120 /// `rhs`, which avoids some extra memory allocations.
121 ///
122 /// `A` and `rhs` must have the same layout, i.e. they must
123 /// be both either row- or column-major format, otherwise a
124 /// `IncompatibleShape` error is raised.
125 fn least_squares_into(self, rhs: ArrayBase<D, I>) -> Result<LeastSquaresResult<E, I>>;
126}
127
128/// Solve least squares for mutable references, overwriting
129/// the input fields in the process
130pub trait LeastSquaresSvdInPlace<D, E, I>
131where
132 D: Data<Elem = E>,
133 E: Scalar + Lapack,
134 I: Dimension,
135{
136 /// Solve a least squares problem of the form `Ax = rhs`
137 /// by calling `A.least_squares(&mut rhs)`, overwriting both `A`
138 /// and `rhs`. This uses the memory location of `A` and
139 /// `rhs`, which avoids some extra memory allocations.
140 ///
141 /// `A` and `rhs` must have the same layout, i.e. they must
142 /// be both either row- or column-major format, otherwise a
143 /// `IncompatibleShape` error is raised.
144 fn least_squares_in_place(
145 &mut self,
146 rhs: &mut ArrayBase<D, I>,
147 ) -> Result<LeastSquaresResult<E, I>>;
148}
149
150/// Solve least squares for immutable references and a single
151/// column vector as a right-hand side.
152/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
153/// valid representation for `ArrayBase` (over `E`).
154impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix1> for ArrayBase<D1, Ix2>
155where
156 E: Scalar + Lapack,
157 D1: Data<Elem = E>,
158 D2: Data<Elem = E>,
159{
160 /// Solve a least squares problem of the form `Ax = rhs`
161 /// by calling `A.least_squares(&rhs)`, where `rhs` is a
162 /// single column vector. `A` and `rhs` are unchanged.
163 ///
164 /// `A` and `rhs` must have the same layout, i.e. they must
165 /// be both either row- or column-major format, otherwise a
166 /// `IncompatibleShape` error is raised.
167 fn least_squares(&self, rhs: &ArrayBase<D2, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
168 let a = self.to_owned();
169 let b = rhs.to_owned();
170 a.least_squares_into(b)
171 }
172}
173
174/// Solve least squares for immutable references and matrix
175/// (=mulitipe vectors) as a right-hand side.
176/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
177/// valid representation for `ArrayBase` (over `E`).
178impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix2> for ArrayBase<D1, Ix2>
179where
180 E: Scalar + Lapack,
181 D1: Data<Elem = E>,
182 D2: Data<Elem = E>,
183{
184 /// Solve a least squares problem of the form `Ax = rhs`
185 /// by calling `A.least_squares(&rhs)`, where `rhs` is
186 /// matrix. `A` and `rhs` are unchanged.
187 ///
188 /// `A` and `rhs` must have the same layout, i.e. they must
189 /// be both either row- or column-major format, otherwise a
190 /// `IncompatibleShape` error is raised.
191 fn least_squares(&self, rhs: &ArrayBase<D2, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
192 let a = self.to_owned();
193 let b = rhs.to_owned();
194 a.least_squares_into(b)
195 }
196}
197
198/// Solve least squares for owned values and a single
199/// column vector as a right-hand side. The matrix and the RHS
200/// vector are consumed.
201///
202/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
203/// valid representation for `ArrayBase`.
204impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix1> for ArrayBase<D1, Ix2>
205where
206 E: Scalar + Lapack,
207 D1: DataMut<Elem = E>,
208 D2: DataMut<Elem = E>,
209{
210 /// Solve a least squares problem of the form `Ax = rhs`
211 /// by calling `A.least_squares(rhs)`, where `rhs` is a
212 /// single column vector. `A` and `rhs` are consumed.
213 ///
214 /// `A` and `rhs` must have the same layout, i.e. they must
215 /// be both either row- or column-major format, otherwise a
216 /// `IncompatibleShape` error is raised.
217 fn least_squares_into(
218 mut self,
219 mut rhs: ArrayBase<D2, Ix1>,
220 ) -> Result<LeastSquaresResult<E, Ix1>> {
221 self.least_squares_in_place(&mut rhs)
222 }
223}
224
225/// Solve least squares for owned values and a matrix
226/// as a right-hand side. The matrix and the RHS matrix
227/// are consumed.
228///
229/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
230/// valid representation for `ArrayBase` (over `E`).
231impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix2> for ArrayBase<D1, Ix2>
232where
233 E: Scalar + Lapack,
234 D1: DataMut<Elem = E>,
235 D2: DataMut<Elem = E>,
236{
237 /// Solve a least squares problem of the form `Ax = rhs`
238 /// by calling `A.least_squares(rhs)`, where `rhs` is a
239 /// matrix. `A` and `rhs` are consumed.
240 ///
241 /// `A` and `rhs` must have the same layout, i.e. they must
242 /// be both either row- or column-major format, otherwise a
243 /// `IncompatibleShape` error is raised.
244 fn least_squares_into(
245 mut self,
246 mut rhs: ArrayBase<D2, Ix2>,
247 ) -> Result<LeastSquaresResult<E, Ix2>> {
248 self.least_squares_in_place(&mut rhs)
249 }
250}
251
252/// Solve least squares for mutable references and a vector
253/// as a right-hand side. Both values are overwritten in the
254/// call.
255///
256/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
257/// valid representation for `ArrayBase` (over `E`).
258impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix1> for ArrayBase<D1, Ix2>
259where
260 E: Scalar + Lapack,
261 D1: DataMut<Elem = E>,
262 D2: DataMut<Elem = E>,
263{
264 /// Solve a least squares problem of the form `Ax = rhs`
265 /// by calling `A.least_squares(rhs)`, where `rhs` is a
266 /// vector. `A` and `rhs` are overwritten in the call.
267 ///
268 /// `A` and `rhs` must have the same layout, i.e. they must
269 /// be both either row- or column-major format, otherwise a
270 /// `IncompatibleShape` error is raised.
271 fn least_squares_in_place(
272 &mut self,
273 rhs: &mut ArrayBase<D2, Ix1>,
274 ) -> Result<LeastSquaresResult<E, Ix1>> {
275 if self.shape()[0] != rhs.shape()[0] {
276 return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
277 }
278 let (m, n) = (self.shape()[0], self.shape()[1]);
279 if n > m {
280 // we need a new rhs b/c it will be overwritten with the solution
281 // for which we need `n` entries
282 let mut new_rhs = Array1::<E>::zeros((n,));
283 new_rhs.slice_mut(s![0..m]).assign(rhs);
284 compute_least_squares_srhs(self, &mut new_rhs)
285 } else {
286 compute_least_squares_srhs(self, rhs)
287 }
288 }
289}
290
291fn compute_least_squares_srhs<E, D1, D2>(
292 a: &mut ArrayBase<D1, Ix2>,
293 rhs: &mut ArrayBase<D2, Ix1>,
294) -> Result<LeastSquaresResult<E, Ix1>>
295where
296 E: Scalar + Lapack,
297 D1: DataMut<Elem = E>,
298 D2: DataMut<Elem = E>,
299{
300 let LeastSquaresOwned::<E> {
301 singular_values,
302 rank,
303 } = E::least_squares(
304 a.layout()?,
305 a.as_allocated_mut()?,
306 rhs.as_slice_memory_order_mut()
307 .ok_or(LinalgError::MemoryNotCont)?,
308 )?;
309
310 let (m, n) = (a.shape()[0], a.shape()[1]);
311 let solution = rhs.slice(s![0..n]).to_owned();
312 let residual_sum_of_squares = compute_residual_scalar(m, n, rank, rhs);
313 Ok(LeastSquaresResult {
314 solution,
315 singular_values: Array::from_shape_vec((singular_values.len(),), singular_values)?,
316 rank,
317 residual_sum_of_squares,
318 })
319}
320
321fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
322 m: usize,
323 n: usize,
324 rank: i32,
325 b: &ArrayBase<D, Ix1>,
326) -> Option<Array<E::Real, Ix0>> {
327 if m < n || n != rank as usize {
328 return None;
329 }
330 let mut arr: Array<E::Real, Ix0> = Array::zeros(());
331 arr[()] = b.slice(s![n..]).mapv(|x| x.powi(2).abs()).sum();
332 Some(arr)
333}
334
335/// Solve least squares for mutable references and a matrix
336/// as a right-hand side. Both values are overwritten in the
337/// call.
338///
339/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
340/// valid representation for `ArrayBase` (over `E`).
341impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix2> for ArrayBase<D1, Ix2>
342where
343 E: Scalar + Lapack,
344 D1: DataMut<Elem = E>,
345 D2: DataMut<Elem = E>,
346{
347 /// Solve a least squares problem of the form `Ax = rhs`
348 /// by calling `A.least_squares(rhs)`, where `rhs` is a
349 /// matrix. `A` and `rhs` are overwritten in the call.
350 ///
351 /// `A` and `rhs` must have the same layout, i.e. they must
352 /// be both either row- or column-major format, otherwise a
353 /// `IncompatibleShape` error is raised.
354 fn least_squares_in_place(
355 &mut self,
356 rhs: &mut ArrayBase<D2, Ix2>,
357 ) -> Result<LeastSquaresResult<E, Ix2>> {
358 if self.shape()[0] != rhs.shape()[0] {
359 return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
360 }
361 let (m, n) = (self.shape()[0], self.shape()[1]);
362 if n > m {
363 // we need a new rhs b/c it will be overwritten with the solution
364 // for which we need `n` entries
365 let k = rhs.shape()[1];
366 let mut new_rhs = match self.layout()? {
367 MatrixLayout::C { .. } => Array2::<E>::zeros((n, k)),
368 MatrixLayout::F { .. } => Array2::<E>::zeros((n, k).f()),
369 };
370 new_rhs.slice_mut(s![0..m, ..]).assign(rhs);
371 compute_least_squares_nrhs(self, &mut new_rhs)
372 } else {
373 compute_least_squares_nrhs(self, rhs)
374 }
375 }
376}
377
378fn compute_least_squares_nrhs<E, D1, D2>(
379 a: &mut ArrayBase<D1, Ix2>,
380 rhs: &mut ArrayBase<D2, Ix2>,
381) -> Result<LeastSquaresResult<E, Ix2>>
382where
383 E: Scalar + Lapack,
384 D1: DataMut<Elem = E>,
385 D2: DataMut<Elem = E>,
386{
387 let a_layout = a.layout()?;
388 let rhs_layout = rhs.layout()?;
389 let LeastSquaresOwned::<E> {
390 singular_values,
391 rank,
392 } = E::least_squares_nrhs(
393 a_layout,
394 a.as_allocated_mut()?,
395 rhs_layout,
396 rhs.as_allocated_mut()?,
397 )?;
398
399 let solution: Array2<E> = rhs.slice(s![..a.shape()[1], ..]).to_owned();
400 let singular_values = Array::from_shape_vec((singular_values.len(),), singular_values)?;
401 let (m, n) = (a.shape()[0], a.shape()[1]);
402 let residual_sum_of_squares = compute_residual_array1(m, n, rank, rhs);
403 Ok(LeastSquaresResult {
404 solution,
405 singular_values,
406 rank,
407 residual_sum_of_squares,
408 })
409}
410
411fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(
412 m: usize,
413 n: usize,
414 rank: i32,
415 b: &ArrayBase<D, Ix2>,
416) -> Option<Array1<E::Real>> {
417 if m < n || n != rank as usize {
418 return None;
419 }
420 Some(
421 b.slice(s![n.., ..])
422 .mapv(|x| x.powi(2).abs())
423 .sum_axis(Axis(0)),
424 )
425}
426
427#[cfg(test)]
428mod tests {
429 use crate::{error::LinalgError, *};
430 use approx::AbsDiffEq;
431 use ndarray::*;
432
433 //
434 // Test that the different least squares traits work as intended on the
435 // different array types.
436 //
437 // | least_squares | ls_into | ls_in_place |
438 // --------------+---------------+---------+-------------+
439 // Array | yes | yes | yes |
440 // ArcArray | yes | no | no |
441 // CowArray | yes | yes | yes |
442 // ArrayView | yes | no | no |
443 // ArrayViewMut | yes | no | yes |
444 //
445
446 fn assert_result<D1: Data<Elem = f64>, D2: Data<Elem = f64>>(
447 a: &ArrayBase<D1, Ix2>,
448 b: &ArrayBase<D2, Ix1>,
449 res: &LeastSquaresResult<f64, Ix1>,
450 ) {
451 assert_eq!(res.rank, 2);
452 let b_hat = a.dot(&res.solution);
453 let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum();
454 assert!(res.residual_sum_of_squares.as_ref().unwrap()[()].abs_diff_eq(&rssq, 1e-12));
455 assert!(res
456 .solution
457 .abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
458 }
459
460 #[test]
461 fn on_arc() {
462 let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
463 let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
464 let res = a.least_squares(&b).unwrap();
465 assert_result(&a, &b, &res);
466 }
467
468 #[test]
469 fn on_cow() {
470 let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
471 let b = CowArray::from(array![1., 2., 3.]);
472 let res = a.least_squares(&b).unwrap();
473 assert_result(&a, &b, &res);
474 }
475
476 #[test]
477 fn on_view() {
478 let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
479 let b: Array1<f64> = array![1., 2., 3.];
480 let av = a.view();
481 let bv = b.view();
482 let res = av.least_squares(&bv).unwrap();
483 assert_result(&av, &bv, &res);
484 }
485
486 #[test]
487 fn on_view_mut() {
488 let mut a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
489 let mut b: Array1<f64> = array![1., 2., 3.];
490 let av = a.view_mut();
491 let bv = b.view_mut();
492 let res = av.least_squares(&bv).unwrap();
493 assert_result(&av, &bv, &res);
494 }
495
496 #[test]
497 fn on_cow_view() {
498 let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
499 let b: Array1<f64> = array![1., 2., 3.];
500 let bv = b.view();
501 let res = a.least_squares(&bv).unwrap();
502 assert_result(&a, &bv, &res);
503 }
504
505 #[test]
506 fn into_on_owned() {
507 let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
508 let b: Array1<f64> = array![1., 2., 3.];
509 let ac = a.clone();
510 let bc = b.clone();
511 let res = ac.least_squares_into(bc).unwrap();
512 assert_result(&a, &b, &res);
513 }
514
515 #[test]
516 fn into_on_arc() {
517 let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
518 let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
519 let a2 = a.clone();
520 let b2 = b.clone();
521 let res = a2.least_squares_into(b2).unwrap();
522 assert_result(&a, &b, &res);
523 }
524
525 #[test]
526 fn into_on_cow() {
527 let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
528 let b = CowArray::from(array![1., 2., 3.]);
529 let a2 = a.clone();
530 let b2 = b.clone();
531 let res = a2.least_squares_into(b2).unwrap();
532 assert_result(&a, &b, &res);
533 }
534
535 #[test]
536 fn into_on_owned_cow() {
537 let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
538 let b = CowArray::from(array![1., 2., 3.]);
539 let ac = a.clone();
540 let b2 = b.clone();
541 let res = ac.least_squares_into(b2).unwrap();
542 assert_result(&a, &b, &res);
543 }
544
545 #[test]
546 fn in_place_on_owned() {
547 let a = array![[1., 2.], [4., 5.], [3., 4.]];
548 let b = array![1., 2., 3.];
549 let mut a2 = a.clone();
550 let mut b2 = b.clone();
551 let res = a2.least_squares_in_place(&mut b2).unwrap();
552 assert_result(&a, &b, &res);
553 }
554
555 #[test]
556 fn in_place_on_cow() {
557 let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
558 let b = CowArray::from(array![1., 2., 3.]);
559 let mut a2 = a.clone();
560 let mut b2 = b.clone();
561 let res = a2.least_squares_in_place(&mut b2).unwrap();
562 assert_result(&a, &b, &res);
563 }
564
565 #[test]
566 fn in_place_on_mut_view() {
567 let a = array![[1., 2.], [4., 5.], [3., 4.]];
568 let b = array![1., 2., 3.];
569 let mut a2 = a.clone();
570 let mut b2 = b.clone();
571 let av = &mut a2.view_mut();
572 let bv = &mut b2.view_mut();
573 let res = av.least_squares_in_place(bv).unwrap();
574 assert_result(&a, &b, &res);
575 }
576
577 #[test]
578 fn in_place_on_owned_cow() {
579 let a = array![[1., 2.], [4., 5.], [3., 4.]];
580 let b = CowArray::from(array![1., 2., 3.]);
581 let mut a2 = a.clone();
582 let mut b2 = b.clone();
583 let res = a2.least_squares_in_place(&mut b2).unwrap();
584 assert_result(&a, &b, &res);
585 }
586
587 //
588 // Testing error cases
589 //
590 #[test]
591 fn incompatible_shape_error_on_mismatching_num_rows() {
592 let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
593 let b: Array1<f64> = array![1., 2.];
594 match a.least_squares(&b) {
595 Err(LinalgError::Shape(e)) if e.kind() == ErrorKind::IncompatibleShape => {}
596 _ => panic!("Should be raise IncompatibleShape"),
597 }
598 }
599}