ndarray_linalg/solve.rs
1//! Solve systems of linear equations and invert matrices
2//!
3//! # Examples
4//!
5//! Solve `A * x = b`:
6//!
7//! ```
8//! use ndarray::prelude::*;
9//! use ndarray_linalg::Solve;
10//!
11//! let a: Array2<f64> = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]];
12//! let b: Array1<f64> = array![1., -2., 0.];
13//! let x = a.solve_into(b).unwrap();
14//! assert!(x.abs_diff_eq(&array![1., -2., -2.], 1e-9));
15//! ```
16//!
17//! There are also special functions for solving `A^T * x = b` and
18//! `A^H * x = b`.
19//!
//! If you are solving multiple systems of linear equations with the same
//! coefficient matrix `A`, it's faster to compute the LU factorization once at
//! the beginning than to solve each system directly with `A` (which
//! refactorizes `A` on every call):
23//!
24//! ```
25//! use ndarray::prelude::*;
26//! use ndarray_linalg::*;
27//!
28//! /// Use fixed algorithm and seed of PRNG for reproducible test
29//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
30//!
31//! let a: Array2<f64> = random_using((3, 3), &mut rng);
32//! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed)
33//! for _ in 0..10 {
34//! let b: Array1<f64> = random_using(3, &mut rng);
35//! let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U
36//! }
37//! ```
38
39use ndarray::*;
40use num_traits::{Float, Zero};
41
42use crate::convert::*;
43use crate::error::*;
44use crate::layout::*;
45use crate::opnorm::OperationNorm;
46use crate::types::*;
47
48pub use lax::{Pivot, Transpose};
49
50/// An interface for solving systems of linear equations.
51///
52/// There are three groups of methods:
53///
54/// * `solve*` (normal) methods solve `A * x = b` for `x`.
55/// * `solve_t*` (transpose) methods solve `A^T * x = b` for `x`.
56/// * `solve_h*` (Hermitian conjugate) methods solve `A^H * x = b` for `x`.
57///
58/// Within each group, there are three methods that handle ownership differently:
59///
60/// * `*` methods take a reference to `b` and return `x` as a new array.
61/// * `*_into` methods take ownership of `b`, store the result in it, and return it.
62/// * `*_inplace` methods take a mutable reference to `b` and store the result in that array.
63///
64/// If you plan to solve many equations with the same `A` matrix but different
65/// `b` vectors, it's faster to factor the `A` matrix once using the
66/// `Factorize` trait, and then solve using the `LUFactorized` struct.
67pub trait Solve<A: Scalar> {
68 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
69 /// is the argument, and `x` is the successful result.
70 ///
71 /// # Panics
72 ///
73 /// Panics if the length of `b` is not the equal to the number of columns
74 /// of `A`.
75 fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
76 let mut b = replicate(b);
77 self.solve_inplace(&mut b)?;
78 Ok(b)
79 }
80
81 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
82 /// is the argument, and `x` is the successful result.
83 ///
84 /// # Panics
85 ///
86 /// Panics if the length of `b` is not the equal to the number of columns
87 /// of `A`.
88 fn solve_into<S: DataMut<Elem = A>>(
89 &self,
90 mut b: ArrayBase<S, Ix1>,
91 ) -> Result<ArrayBase<S, Ix1>> {
92 self.solve_inplace(&mut b)?;
93 Ok(b)
94 }
95
96 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
97 /// is the argument, and `x` is the successful result.
98 ///
99 /// # Panics
100 ///
101 /// Panics if the length of `b` is not the equal to the number of columns
102 /// of `A`.
103 fn solve_inplace<'a, S: DataMut<Elem = A>>(
104 &self,
105 b: &'a mut ArrayBase<S, Ix1>,
106 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
107
108 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
109 /// is the argument, and `x` is the successful result.
110 ///
111 /// # Panics
112 ///
113 /// Panics if the length of `b` is not the equal to the number of rows of
114 /// `A`.
115 fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
116 let mut b = replicate(b);
117 self.solve_t_inplace(&mut b)?;
118 Ok(b)
119 }
120
121 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
122 /// is the argument, and `x` is the successful result.
123 ///
124 /// # Panics
125 ///
126 /// Panics if the length of `b` is not the equal to the number of rows of
127 /// `A`.
128 fn solve_t_into<S: DataMut<Elem = A>>(
129 &self,
130 mut b: ArrayBase<S, Ix1>,
131 ) -> Result<ArrayBase<S, Ix1>> {
132 self.solve_t_inplace(&mut b)?;
133 Ok(b)
134 }
135
136 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
137 /// is the argument, and `x` is the successful result.
138 ///
139 /// # Panics
140 ///
141 /// Panics if the length of `b` is not the equal to the number of rows of
142 /// `A`.
143 fn solve_t_inplace<'a, S: DataMut<Elem = A>>(
144 &self,
145 b: &'a mut ArrayBase<S, Ix1>,
146 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
147
148 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
149 /// is the argument, and `x` is the successful result.
150 ///
151 /// # Panics
152 ///
153 /// Panics if the length of `b` is not the equal to the number of rows of
154 /// `A`.
155 fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
156 let mut b = replicate(b);
157 self.solve_h_inplace(&mut b)?;
158 Ok(b)
159 }
160 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
161 /// is the argument, and `x` is the successful result.
162 ///
163 /// # Panics
164 ///
165 /// Panics if the length of `b` is not the equal to the number of rows of
166 /// `A`.
167 fn solve_h_into<S: DataMut<Elem = A>>(
168 &self,
169 mut b: ArrayBase<S, Ix1>,
170 ) -> Result<ArrayBase<S, Ix1>> {
171 self.solve_h_inplace(&mut b)?;
172 Ok(b)
173 }
174 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
175 /// is the argument, and `x` is the successful result.
176 ///
177 /// # Panics
178 ///
179 /// Panics if the length of `b` is not the equal to the number of rows of
180 /// `A`.
181 fn solve_h_inplace<'a, S: DataMut<Elem = A>>(
182 &self,
183 b: &'a mut ArrayBase<S, Ix1>,
184 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
185}
186
187/// Represents the LU factorization of a matrix `A` as `A = P*L*U`.
188#[derive(Clone)]
189pub struct LUFactorized<S: Data + RawDataClone> {
190 /// The factors `L` and `U`; the unit diagonal elements of `L` are not
191 /// stored.
192 a: ArrayBase<S, Ix2>,
193 /// The pivot indices that define the permutation matrix `P`.
194 ipiv: Pivot,
195}
196
197impl<A, S> Solve<A> for LUFactorized<S>
198where
199 A: Scalar + Lapack,
200 S: Data<Elem = A> + RawDataClone,
201{
202 fn solve_inplace<'a, Sb>(
203 &self,
204 rhs: &'a mut ArrayBase<Sb, Ix1>,
205 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
206 where
207 Sb: DataMut<Elem = A>,
208 {
209 assert_eq!(
210 rhs.len(),
211 self.a.len_of(Axis(1)),
212 "The length of `rhs` must be compatible with the shape of the factored matrix.",
213 );
214 A::solve(
215 self.a.square_layout()?,
216 Transpose::No,
217 self.a.as_allocated()?,
218 &self.ipiv,
219 rhs.as_slice_mut().unwrap(),
220 )?;
221 Ok(rhs)
222 }
223 fn solve_t_inplace<'a, Sb>(
224 &self,
225 rhs: &'a mut ArrayBase<Sb, Ix1>,
226 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
227 where
228 Sb: DataMut<Elem = A>,
229 {
230 assert_eq!(
231 rhs.len(),
232 self.a.len_of(Axis(0)),
233 "The length of `rhs` must be compatible with the shape of the factored matrix.",
234 );
235 A::solve(
236 self.a.square_layout()?,
237 Transpose::Transpose,
238 self.a.as_allocated()?,
239 &self.ipiv,
240 rhs.as_slice_mut().unwrap(),
241 )?;
242 Ok(rhs)
243 }
244 fn solve_h_inplace<'a, Sb>(
245 &self,
246 rhs: &'a mut ArrayBase<Sb, Ix1>,
247 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
248 where
249 Sb: DataMut<Elem = A>,
250 {
251 assert_eq!(
252 rhs.len(),
253 self.a.len_of(Axis(0)),
254 "The length of `rhs` must be compatible with the shape of the factored matrix.",
255 );
256 A::solve(
257 self.a.square_layout()?,
258 Transpose::Hermite,
259 self.a.as_allocated()?,
260 &self.ipiv,
261 rhs.as_slice_mut().unwrap(),
262 )?;
263 Ok(rhs)
264 }
265}
266
267impl<A, S> Solve<A> for ArrayBase<S, Ix2>
268where
269 A: Scalar + Lapack,
270 S: Data<Elem = A>,
271{
272 fn solve_inplace<'a, Sb>(
273 &self,
274 rhs: &'a mut ArrayBase<Sb, Ix1>,
275 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
276 where
277 Sb: DataMut<Elem = A>,
278 {
279 let f = self.factorize()?;
280 f.solve_inplace(rhs)
281 }
282 fn solve_t_inplace<'a, Sb>(
283 &self,
284 rhs: &'a mut ArrayBase<Sb, Ix1>,
285 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
286 where
287 Sb: DataMut<Elem = A>,
288 {
289 let f = self.factorize()?;
290 f.solve_t_inplace(rhs)
291 }
292 fn solve_h_inplace<'a, Sb>(
293 &self,
294 rhs: &'a mut ArrayBase<Sb, Ix1>,
295 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
296 where
297 Sb: DataMut<Elem = A>,
298 {
299 let f = self.factorize()?;
300 f.solve_h_inplace(rhs)
301 }
302}
303
304/// An interface for computing LU factorizations of matrix refs.
305pub trait Factorize<S: Data + RawDataClone> {
306 /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
307 /// matrix.
308 fn factorize(&self) -> Result<LUFactorized<S>>;
309}
310
311/// An interface for computing LU factorizations of matrices.
312pub trait FactorizeInto<S: Data + RawDataClone> {
313 /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
314 /// matrix.
315 fn factorize_into(self) -> Result<LUFactorized<S>>;
316}
317
318impl<A, S> FactorizeInto<S> for ArrayBase<S, Ix2>
319where
320 A: Scalar + Lapack,
321 S: DataMut<Elem = A> + RawDataClone,
322{
323 fn factorize_into(mut self) -> Result<LUFactorized<S>> {
324 let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
325 Ok(LUFactorized { a: self, ipiv })
326 }
327}
328
329impl<A, Si> Factorize<OwnedRepr<A>> for ArrayBase<Si, Ix2>
330where
331 A: Scalar + Lapack,
332 Si: Data<Elem = A>,
333{
334 fn factorize(&self) -> Result<LUFactorized<OwnedRepr<A>>> {
335 let mut a: Array2<A> = replicate(self);
336 let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
337 Ok(LUFactorized { a, ipiv })
338 }
339}
340
341/// An interface for inverting matrix refs.
342pub trait Inverse {
343 type Output;
344 /// Computes the inverse of the matrix.
345 fn inv(&self) -> Result<Self::Output>;
346}
347
348/// An interface for inverting matrices.
349pub trait InverseInto {
350 type Output;
351 /// Computes the inverse of the matrix.
352 fn inv_into(self) -> Result<Self::Output>;
353}
354
355impl<A, S> InverseInto for LUFactorized<S>
356where
357 A: Scalar + Lapack,
358 S: DataMut<Elem = A> + RawDataClone,
359{
360 type Output = ArrayBase<S, Ix2>;
361
362 fn inv_into(mut self) -> Result<ArrayBase<S, Ix2>> {
363 A::inv(
364 self.a.square_layout()?,
365 self.a.as_allocated_mut()?,
366 &self.ipiv,
367 )?;
368 Ok(self.a)
369 }
370}
371
372impl<A, S> Inverse for LUFactorized<S>
373where
374 A: Scalar + Lapack,
375 S: Data<Elem = A> + RawDataClone,
376{
377 type Output = Array2<A>;
378
379 fn inv(&self) -> Result<Array2<A>> {
380 // Preserve the existing layout. This is required to obtain the correct
381 // result, because the result of `A::inv` is layout-dependent.
382 let a = if self.a.is_standard_layout() {
383 replicate(&self.a)
384 } else {
385 replicate(&self.a.t()).reversed_axes()
386 };
387 let f = LUFactorized {
388 a,
389 ipiv: self.ipiv.clone(),
390 };
391 f.inv_into()
392 }
393}
394
395impl<A, S> InverseInto for ArrayBase<S, Ix2>
396where
397 A: Scalar + Lapack,
398 S: DataMut<Elem = A> + RawDataClone,
399{
400 type Output = Self;
401
402 fn inv_into(self) -> Result<Self::Output> {
403 let f = self.factorize_into()?;
404 f.inv_into()
405 }
406}
407
408impl<A, Si> Inverse for ArrayBase<Si, Ix2>
409where
410 A: Scalar + Lapack,
411 Si: Data<Elem = A>,
412{
413 type Output = Array2<A>;
414
415 fn inv(&self) -> Result<Self::Output> {
416 let f = self.factorize()?;
417 f.inv_into()
418 }
419}
420
421/// An interface for calculating determinants of matrix refs.
422pub trait Determinant<A: Scalar> {
423 /// Computes the determinant of the matrix.
424 fn det(&self) -> Result<A> {
425 let (sign, ln_det) = self.sln_det()?;
426 Ok(sign * A::from_real(Float::exp(ln_det)))
427 }
428
429 /// Computes the `(sign, natural_log)` of the determinant of the matrix.
430 ///
431 /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
432 /// `sign` is `0` or a complex number with absolute value 1. The
433 /// `natural_log` is the natural logarithm of the absolute value of the
434 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
435 /// is negative infinity.
436 ///
437 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
438 /// or just call `.det()` instead.
439 ///
440 /// This method is more robust than `.det()` to very small or very large
441 /// determinants since it returns the natural logarithm of the determinant
442 /// rather than the determinant itself.
443 fn sln_det(&self) -> Result<(A, A::Real)>;
444}
445
446/// An interface for calculating determinants of matrices.
447pub trait DeterminantInto<A: Scalar>: Sized {
448 /// Computes the determinant of the matrix.
449 fn det_into(self) -> Result<A> {
450 let (sign, ln_det) = self.sln_det_into()?;
451 Ok(sign * A::from_real(Float::exp(ln_det)))
452 }
453
454 /// Computes the `(sign, natural_log)` of the determinant of the matrix.
455 ///
456 /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
457 /// `sign` is `0` or a complex number with absolute value 1. The
458 /// `natural_log` is the natural logarithm of the absolute value of the
459 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
460 /// is negative infinity.
461 ///
462 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
463 /// or just call `.det_into()` instead.
464 ///
465 /// This method is more robust than `.det()` to very small or very large
466 /// determinants since it returns the natural logarithm of the determinant
467 /// rather than the determinant itself.
468 fn sln_det_into(self) -> Result<(A, A::Real)>;
469}
470
471fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
472where
473 A: Scalar + Lapack,
474 P: Iterator<Item = i32>,
475 U: Iterator<Item = &'a A>,
476{
477 let pivot_sign = if ipiv_iter
478 .enumerate()
479 .filter(|&(i, pivot)| pivot != i as i32 + 1)
480 .count()
481 % 2
482 == 0
483 {
484 A::one()
485 } else {
486 -A::one()
487 };
488 let (upper_sign, ln_det) = u_diag_iter.fold(
489 (A::one(), A::Real::zero()),
490 |(upper_sign, ln_det), &elem| {
491 let abs_elem: A::Real = elem.abs();
492 (
493 upper_sign * elem / A::from_real(abs_elem),
494 ln_det + Float::ln(abs_elem),
495 )
496 },
497 );
498 (pivot_sign * upper_sign, ln_det)
499}
500
501impl<A, S> Determinant<A> for LUFactorized<S>
502where
503 A: Scalar + Lapack,
504 S: Data<Elem = A> + RawDataClone,
505{
506 fn sln_det(&self) -> Result<(A, A::Real)> {
507 self.a.ensure_square()?;
508 Ok(lu_sln_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
509 }
510}
511
512impl<A, S> DeterminantInto<A> for LUFactorized<S>
513where
514 A: Scalar + Lapack,
515 S: Data<Elem = A> + RawDataClone,
516{
517 fn sln_det_into(self) -> Result<(A, A::Real)> {
518 self.a.ensure_square()?;
519 Ok(lu_sln_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
520 }
521}
522
523impl<A, S> Determinant<A> for ArrayBase<S, Ix2>
524where
525 A: Scalar + Lapack,
526 S: Data<Elem = A>,
527{
528 fn sln_det(&self) -> Result<(A, A::Real)> {
529 self.ensure_square()?;
530 match self.factorize() {
531 Ok(fac) => fac.sln_det(),
532 Err(LinalgError::Lapack(e))
533 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
534 {
535 // The determinant is zero.
536 Ok((A::zero(), A::Real::neg_infinity()))
537 }
538 Err(err) => Err(err),
539 }
540 }
541}
542
543impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
544where
545 A: Scalar + Lapack,
546 S: DataMut<Elem = A> + RawDataClone,
547{
548 fn sln_det_into(self) -> Result<(A, A::Real)> {
549 self.ensure_square()?;
550 match self.factorize_into() {
551 Ok(fac) => fac.sln_det_into(),
552 Err(LinalgError::Lapack(e))
553 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
554 {
555 // The determinant is zero.
556 Ok((A::zero(), A::Real::neg_infinity()))
557 }
558 Err(err) => Err(err),
559 }
560 }
561}
562
563/// An interface for *estimating* the reciprocal condition number of matrix refs.
564pub trait ReciprocalConditionNum<A: Scalar> {
565 /// *Estimates* the reciprocal of the condition number of the matrix in
566 /// 1-norm.
567 ///
568 /// This method uses the LAPACK `*gecon` routines, which *estimate*
569 /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
570 /// (self.opnorm_one() * self.inv().opnorm_one())`.
571 ///
572 /// * If `rcond` is near `0.`, the matrix is badly conditioned.
573 /// * If `rcond` is near `1.`, the matrix is well conditioned.
574 fn rcond(&self) -> Result<A::Real>;
575}
576
577/// An interface for *estimating* the reciprocal condition number of matrices.
578pub trait ReciprocalConditionNumInto<A: Scalar> {
579 /// *Estimates* the reciprocal of the condition number of the matrix in
580 /// 1-norm.
581 ///
582 /// This method uses the LAPACK `*gecon` routines, which *estimate*
583 /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
584 /// (self.opnorm_one() * self.inv().opnorm_one())`.
585 ///
586 /// * If `rcond` is near `0.`, the matrix is badly conditioned.
587 /// * If `rcond` is near `1.`, the matrix is well conditioned.
588 fn rcond_into(self) -> Result<A::Real>;
589}
590
591impl<A, S> ReciprocalConditionNum<A> for LUFactorized<S>
592where
593 A: Scalar + Lapack,
594 S: Data<Elem = A> + RawDataClone,
595{
596 fn rcond(&self) -> Result<A::Real> {
597 Ok(A::rcond(
598 self.a.layout()?,
599 self.a.as_allocated()?,
600 self.a.opnorm_one()?,
601 )?)
602 }
603}
604
605impl<A, S> ReciprocalConditionNumInto<A> for LUFactorized<S>
606where
607 A: Scalar + Lapack,
608 S: Data<Elem = A> + RawDataClone,
609{
610 fn rcond_into(self) -> Result<A::Real> {
611 self.rcond()
612 }
613}
614
615impl<A, S> ReciprocalConditionNum<A> for ArrayBase<S, Ix2>
616where
617 A: Scalar + Lapack,
618 S: Data<Elem = A>,
619{
620 fn rcond(&self) -> Result<A::Real> {
621 self.factorize()?.rcond_into()
622 }
623}
624
625impl<A, S> ReciprocalConditionNumInto<A> for ArrayBase<S, Ix2>
626where
627 A: Scalar + Lapack,
628 S: DataMut<Elem = A> + RawDataClone,
629{
630 fn rcond_into(self) -> Result<A::Real> {
631 self.factorize_into()?.rcond_into()
632 }
633}