ndarray_linalg/
solve.rs

1//! Solve systems of linear equations and invert matrices
2//!
3//! # Examples
4//!
5//! Solve `A * x = b`:
6//!
7//! ```
8//! use ndarray::prelude::*;
9//! use ndarray_linalg::Solve;
10//!
11//! let a: Array2<f64> = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]];
12//! let b: Array1<f64> = array![1., -2., 0.];
13//! let x = a.solve_into(b).unwrap();
14//! assert!(x.abs_diff_eq(&array![1., -2., -2.], 1e-9));
15//! ```
16//!
17//! There are also special functions for solving `A^T * x = b` and
18//! `A^H * x = b`.
19//!
20//! If you are solving multiple systems of linear equations with the same
21//! coefficient matrix `A`, it's faster to compute the LU factorization once at
22//! the beginning than solving directly using `A`:
23//!
24//! ```
25//! use ndarray::prelude::*;
26//! use ndarray_linalg::*;
27//!
//! // Use a fixed PRNG algorithm and seed so that the example is reproducible.
29//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
30//!
31//! let a: Array2<f64> = random_using((3, 3), &mut rng);
32//! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed)
33//! for _ in 0..10 {
//!     let b: Array1<f64> = random_using(3, &mut rng);
35//!     let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U
36//! }
37//! ```
38
39use ndarray::*;
40use num_traits::{Float, Zero};
41
42use crate::convert::*;
43use crate::error::*;
44use crate::layout::*;
45use crate::opnorm::OperationNorm;
46use crate::types::*;
47
48pub use lax::{Pivot, Transpose};
49
50/// An interface for solving systems of linear equations.
51///
52/// There are three groups of methods:
53///
54/// * `solve*` (normal) methods solve `A * x = b` for `x`.
55/// * `solve_t*` (transpose) methods solve `A^T * x = b` for `x`.
56/// * `solve_h*` (Hermitian conjugate) methods solve `A^H * x = b` for `x`.
57///
58/// Within each group, there are three methods that handle ownership differently:
59///
60/// * `*` methods take a reference to `b` and return `x` as a new array.
61/// * `*_into` methods take ownership of `b`, store the result in it, and return it.
62/// * `*_inplace` methods take a mutable reference to `b` and store the result in that array.
63///
64/// If you plan to solve many equations with the same `A` matrix but different
65/// `b` vectors, it's faster to factor the `A` matrix once using the
66/// `Factorize` trait, and then solve using the `LUFactorized` struct.
67pub trait Solve<A: Scalar> {
68    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
69    /// is the argument, and `x` is the successful result.
70    ///
71    /// # Panics
72    ///
73    /// Panics if the length of `b` is not the equal to the number of columns
74    /// of `A`.
75    fn solve(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
76        let mut b = replicate(b);
77        self.solve_inplace(&mut b)?;
78        Ok(b)
79    }
80
81    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
82    /// is the argument, and `x` is the successful result.
83    ///
84    /// # Panics
85    ///
86    /// Panics if the length of `b` is not the equal to the number of columns
87    /// of `A`.
88    fn solve_into<S: DataMut<Elem = A>>(
89        &self,
90        mut b: ArrayBase<S, Ix1>,
91    ) -> Result<ArrayBase<S, Ix1>> {
92        self.solve_inplace(&mut b)?;
93        Ok(b)
94    }
95
96    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
97    /// is the argument, and `x` is the successful result.
98    ///
99    /// # Panics
100    ///
101    /// Panics if the length of `b` is not the equal to the number of columns
102    /// of `A`.
103    fn solve_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
104
105    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
106    /// is the argument, and `x` is the successful result.
107    ///
108    /// # Panics
109    ///
110    /// Panics if the length of `b` is not the equal to the number of rows of
111    /// `A`.
112    fn solve_t(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
113        let mut b = replicate(b);
114        self.solve_t_inplace(&mut b)?;
115        Ok(b)
116    }
117
118    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
119    /// is the argument, and `x` is the successful result.
120    ///
121    /// # Panics
122    ///
123    /// Panics if the length of `b` is not the equal to the number of rows of
124    /// `A`.
125    fn solve_t_into<S: DataMut<Elem = A>>(
126        &self,
127        mut b: ArrayBase<S, Ix1>,
128    ) -> Result<ArrayBase<S, Ix1>> {
129        self.solve_t_inplace(&mut b)?;
130        Ok(b)
131    }
132
133    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
134    /// is the argument, and `x` is the successful result.
135    ///
136    /// # Panics
137    ///
138    /// Panics if the length of `b` is not the equal to the number of rows of
139    /// `A`.
140    fn solve_t_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
141
142    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
143    /// is the argument, and `x` is the successful result.
144    ///
145    /// # Panics
146    ///
147    /// Panics if the length of `b` is not the equal to the number of rows of
148    /// `A`.
149    fn solve_h(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
150        let mut b = replicate(b);
151        self.solve_h_inplace(&mut b)?;
152        Ok(b)
153    }
154    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
155    /// is the argument, and `x` is the successful result.
156    ///
157    /// # Panics
158    ///
159    /// Panics if the length of `b` is not the equal to the number of rows of
160    /// `A`.
161    fn solve_h_into<S: DataMut<Elem = A>>(
162        &self,
163        mut b: ArrayBase<S, Ix1>,
164    ) -> Result<ArrayBase<S, Ix1>> {
165        self.solve_h_inplace(&mut b)?;
166        Ok(b)
167    }
168    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
169    /// is the argument, and `x` is the successful result.
170    ///
171    /// # Panics
172    ///
173    /// Panics if the length of `b` is not the equal to the number of rows of
174    /// `A`.
175    fn solve_h_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
176}
177
/// Represents the LU factorization of a matrix `A` as `A = P*L*U`.
///
/// Values of this type are produced by `Factorize::factorize` and
/// `FactorizeInto::factorize_into`, and can then be reused to solve several
/// systems, invert the matrix, or compute its determinant without factoring
/// again.
#[derive(Clone)]
pub struct LUFactorized<S: Data + RawDataClone> {
    /// The factors `L` and `U`; the unit diagonal elements of `L` are not
    /// stored.
    a: ArrayBase<S, Ix2>,
    /// The pivot indices that define the permutation matrix `P`.
    ipiv: Pivot,
}
187
188impl<A, S> Solve<A> for LUFactorized<S>
189where
190    A: Scalar + Lapack,
191    S: Data<Elem = A> + RawDataClone,
192{
193    fn solve_inplace<'a>(&self, rhs: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>> {
194        assert_eq!(
195            rhs.len(),
196            self.a.len_of(Axis(1)),
197            "The length of `rhs` must be compatible with the shape of the factored matrix.",
198        );
199        A::solve(
200            self.a.square_layout()?,
201            Transpose::No,
202            self.a.as_allocated()?,
203            &self.ipiv,
204            rhs.as_slice_mut().unwrap(),
205        )?;
206        Ok(rhs)
207    }
208    fn solve_t_inplace<'a>(
209        &self,
210        rhs: &'a mut ArrayRef<A, Ix1>,
211    ) -> Result<&'a mut ArrayRef<A, Ix1>> {
212        assert_eq!(
213            rhs.len(),
214            self.a.len_of(Axis(0)),
215            "The length of `rhs` must be compatible with the shape of the factored matrix.",
216        );
217        A::solve(
218            self.a.square_layout()?,
219            Transpose::Transpose,
220            self.a.as_allocated()?,
221            &self.ipiv,
222            rhs.as_slice_mut().unwrap(),
223        )?;
224        Ok(rhs)
225    }
226    fn solve_h_inplace<'a>(
227        &self,
228        rhs: &'a mut ArrayRef<A, Ix1>,
229    ) -> Result<&'a mut ArrayRef<A, Ix1>> {
230        assert_eq!(
231            rhs.len(),
232            self.a.len_of(Axis(0)),
233            "The length of `rhs` must be compatible with the shape of the factored matrix.",
234        );
235        A::solve(
236            self.a.square_layout()?,
237            Transpose::Hermite,
238            self.a.as_allocated()?,
239            &self.ipiv,
240            rhs.as_slice_mut().unwrap(),
241        )?;
242        Ok(rhs)
243    }
244}
245
246impl<A> Solve<A> for ArrayRef<A, Ix2>
247where
248    A: Scalar + Lapack,
249{
250    fn solve_inplace<'a>(&self, rhs: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>> {
251        let f = self.factorize()?;
252        f.solve_inplace(rhs)
253    }
254    fn solve_t_inplace<'a>(
255        &self,
256        rhs: &'a mut ArrayRef<A, Ix1>,
257    ) -> Result<&'a mut ArrayRef<A, Ix1>> {
258        let f = self.factorize()?;
259        f.solve_t_inplace(rhs)
260    }
261    fn solve_h_inplace<'a>(
262        &self,
263        rhs: &'a mut ArrayRef<A, Ix1>,
264    ) -> Result<&'a mut ArrayRef<A, Ix1>> {
265        let f = self.factorize()?;
266        f.solve_h_inplace(rhs)
267    }
268}
269
/// An interface for computing LU factorizations of matrix refs.
///
/// The matrix is only borrowed; `S` is the storage type of the factors held
/// by the returned `LUFactorized`.
pub trait Factorize<S: Data + RawDataClone> {
    /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
    /// matrix.
    ///
    /// Returns an error if the factorization cannot be computed.
    fn factorize(&self) -> Result<LUFactorized<S>>;
}
276
/// An interface for computing LU factorizations of matrices.
///
/// The matrix is consumed; `S` is the storage type of the factors held by the
/// returned `LUFactorized`.
pub trait FactorizeInto<S: Data + RawDataClone> {
    /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
    /// matrix.
    ///
    /// Returns an error if the factorization cannot be computed.
    fn factorize_into(self) -> Result<LUFactorized<S>>;
}
283
284impl<A, S> FactorizeInto<S> for ArrayBase<S, Ix2>
285where
286    A: Scalar + Lapack,
287    S: DataMut<Elem = A> + RawDataClone,
288{
289    fn factorize_into(mut self) -> Result<LUFactorized<S>> {
290        let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
291        Ok(LUFactorized { a: self, ipiv })
292    }
293}
294
295impl<A> Factorize<OwnedRepr<A>> for ArrayRef<A, Ix2>
296where
297    A: Scalar + Lapack,
298{
299    fn factorize(&self) -> Result<LUFactorized<OwnedRepr<A>>> {
300        let mut a: Array2<A> = replicate(self);
301        let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
302        Ok(LUFactorized { a, ipiv })
303    }
304}
305
/// An interface for inverting matrix refs.
///
/// The matrix is only borrowed; the inverse is returned as a separate value.
pub trait Inverse {
    /// The type in which the inverse is returned.
    type Output;
    /// Computes the inverse of the matrix.
    ///
    /// Returns an error if the inverse cannot be computed.
    fn inv(&self) -> Result<Self::Output>;
}
312
/// An interface for inverting matrices.
///
/// The matrix is consumed by the computation.
pub trait InverseInto {
    /// The type in which the inverse is returned.
    type Output;
    /// Computes the inverse of the matrix.
    ///
    /// Returns an error if the inverse cannot be computed.
    fn inv_into(self) -> Result<Self::Output>;
}
319
320impl<A, S> InverseInto for LUFactorized<S>
321where
322    A: Scalar + Lapack,
323    S: DataMut<Elem = A> + RawDataClone,
324{
325    type Output = ArrayBase<S, Ix2>;
326
327    fn inv_into(mut self) -> Result<ArrayBase<S, Ix2>> {
328        A::inv(
329            self.a.square_layout()?,
330            self.a.as_allocated_mut()?,
331            &self.ipiv,
332        )?;
333        Ok(self.a)
334    }
335}
336
337impl<A, S> Inverse for LUFactorized<S>
338where
339    A: Scalar + Lapack,
340    S: Data<Elem = A> + RawDataClone,
341{
342    type Output = Array2<A>;
343
344    fn inv(&self) -> Result<Array2<A>> {
345        // Preserve the existing layout. This is required to obtain the correct
346        // result, because the result of `A::inv` is layout-dependent.
347        let a = if self.a.is_standard_layout() {
348            replicate(&self.a)
349        } else {
350            replicate(&self.a.t()).reversed_axes()
351        };
352        let f = LUFactorized {
353            a,
354            ipiv: self.ipiv.clone(),
355        };
356        f.inv_into()
357    }
358}
359
360impl<A, S> InverseInto for ArrayBase<S, Ix2>
361where
362    A: Scalar + Lapack,
363    S: DataMut<Elem = A> + RawDataClone,
364{
365    type Output = Self;
366
367    fn inv_into(self) -> Result<Self::Output> {
368        let f = self.factorize_into()?;
369        f.inv_into()
370    }
371}
372
373impl<A> Inverse for ArrayRef<A, Ix2>
374where
375    A: Scalar + Lapack,
376{
377    type Output = Array2<A>;
378
379    fn inv(&self) -> Result<Self::Output> {
380        let f = self.factorize()?;
381        f.inv_into()
382    }
383}
384
385/// An interface for calculating determinants of matrix refs.
386pub trait Determinant<A: Scalar> {
387    /// Computes the determinant of the matrix.
388    fn det(&self) -> Result<A> {
389        let (sign, ln_det) = self.sln_det()?;
390        Ok(sign * A::from_real(Float::exp(ln_det)))
391    }
392
393    /// Computes the `(sign, natural_log)` of the determinant of the matrix.
394    ///
395    /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
396    /// `sign` is `0` or a complex number with absolute value 1. The
397    /// `natural_log` is the natural logarithm of the absolute value of the
398    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
399    /// is negative infinity.
400    ///
401    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
402    /// or just call `.det()` instead.
403    ///
404    /// This method is more robust than `.det()` to very small or very large
405    /// determinants since it returns the natural logarithm of the determinant
406    /// rather than the determinant itself.
407    fn sln_det(&self) -> Result<(A, A::Real)>;
408}
409
410/// An interface for calculating determinants of matrices.
411pub trait DeterminantInto<A: Scalar>: Sized {
412    /// Computes the determinant of the matrix.
413    fn det_into(self) -> Result<A> {
414        let (sign, ln_det) = self.sln_det_into()?;
415        Ok(sign * A::from_real(Float::exp(ln_det)))
416    }
417
418    /// Computes the `(sign, natural_log)` of the determinant of the matrix.
419    ///
420    /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
421    /// `sign` is `0` or a complex number with absolute value 1. The
422    /// `natural_log` is the natural logarithm of the absolute value of the
423    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
424    /// is negative infinity.
425    ///
426    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
427    /// or just call `.det_into()` instead.
428    ///
429    /// This method is more robust than `.det()` to very small or very large
430    /// determinants since it returns the natural logarithm of the determinant
431    /// rather than the determinant itself.
432    fn sln_det_into(self) -> Result<(A, A::Real)>;
433}
434
435fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
436where
437    A: Scalar + Lapack,
438    P: Iterator<Item = i32>,
439    U: Iterator<Item = &'a A>,
440{
441    let pivot_sign = if ipiv_iter
442        .enumerate()
443        .filter(|&(i, pivot)| pivot != i as i32 + 1)
444        .count()
445        % 2
446        == 0
447    {
448        A::one()
449    } else {
450        -A::one()
451    };
452    let (upper_sign, ln_det) = u_diag_iter.fold(
453        (A::one(), A::Real::zero()),
454        |(upper_sign, ln_det), &elem| {
455            let abs_elem: A::Real = elem.abs();
456            (
457                upper_sign * elem / A::from_real(abs_elem),
458                ln_det + Float::ln(abs_elem),
459            )
460        },
461    );
462    (pivot_sign * upper_sign, ln_det)
463}
464
465impl<A, S> Determinant<A> for LUFactorized<S>
466where
467    A: Scalar + Lapack,
468    S: Data<Elem = A> + RawDataClone,
469{
470    fn sln_det(&self) -> Result<(A, A::Real)> {
471        self.a.ensure_square()?;
472        Ok(lu_sln_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
473    }
474}
475
476impl<A, S> DeterminantInto<A> for LUFactorized<S>
477where
478    A: Scalar + Lapack,
479    S: Data<Elem = A> + RawDataClone,
480{
481    fn sln_det_into(self) -> Result<(A, A::Real)> {
482        self.a.ensure_square()?;
483        Ok(lu_sln_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
484    }
485}
486
487impl<A> Determinant<A> for ArrayRef<A, Ix2>
488where
489    A: Scalar + Lapack,
490{
491    fn sln_det(&self) -> Result<(A, A::Real)> {
492        self.ensure_square()?;
493        match self.factorize() {
494            Ok(fac) => fac.sln_det(),
495            Err(LinalgError::Lapack(e))
496                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
497            {
498                // The determinant is zero.
499                Ok((A::zero(), A::Real::neg_infinity()))
500            }
501            Err(err) => Err(err),
502        }
503    }
504}
505
506impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
507where
508    A: Scalar + Lapack,
509    S: DataMut<Elem = A> + RawDataClone,
510{
511    fn sln_det_into(self) -> Result<(A, A::Real)> {
512        self.ensure_square()?;
513        match self.factorize_into() {
514            Ok(fac) => fac.sln_det_into(),
515            Err(LinalgError::Lapack(e))
516                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
517            {
518                // The determinant is zero.
519                Ok((A::zero(), A::Real::neg_infinity()))
520            }
521            Err(err) => Err(err),
522        }
523    }
524}
525
/// An interface for *estimating* the reciprocal condition number of matrix refs.
///
/// The matrix is only borrowed by the computation.
pub trait ReciprocalConditionNum<A: Scalar> {
    /// *Estimates* the reciprocal of the condition number of the matrix in
    /// 1-norm.
    ///
    /// This method uses the LAPACK `*gecon` routines, which *estimate*
    /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
    /// (self.opnorm_one() * self.inv().opnorm_one())`.
    ///
    /// * If `rcond` is near `0.`, the matrix is badly conditioned.
    /// * If `rcond` is near `1.`, the matrix is well conditioned.
    fn rcond(&self) -> Result<A::Real>;
}
539
/// An interface for *estimating* the reciprocal condition number of matrices.
///
/// The matrix is consumed by the computation.
pub trait ReciprocalConditionNumInto<A: Scalar> {
    /// *Estimates* the reciprocal of the condition number of the matrix in
    /// 1-norm.
    ///
    /// This method uses the LAPACK `*gecon` routines, which *estimate*
    /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
    /// (self.opnorm_one() * self.inv().opnorm_one())`.
    ///
    /// * If `rcond` is near `0.`, the matrix is badly conditioned.
    /// * If `rcond` is near `1.`, the matrix is well conditioned.
    fn rcond_into(self) -> Result<A::Real>;
}
553
554impl<A, S> ReciprocalConditionNum<A> for LUFactorized<S>
555where
556    A: Scalar + Lapack,
557    S: Data<Elem = A> + RawDataClone,
558{
559    fn rcond(&self) -> Result<A::Real> {
560        Ok(A::rcond(
561            self.a.layout()?,
562            self.a.as_allocated()?,
563            self.a.opnorm_one()?,
564        )?)
565    }
566}
567
568impl<A, S> ReciprocalConditionNumInto<A> for LUFactorized<S>
569where
570    A: Scalar + Lapack,
571    S: Data<Elem = A> + RawDataClone,
572{
573    fn rcond_into(self) -> Result<A::Real> {
574        self.rcond()
575    }
576}
577
578impl<A> ReciprocalConditionNum<A> for ArrayRef<A, Ix2>
579where
580    A: Scalar + Lapack,
581{
582    fn rcond(&self) -> Result<A::Real> {
583        self.factorize()?.rcond_into()
584    }
585}
586
587impl<A, S> ReciprocalConditionNumInto<A> for ArrayBase<S, Ix2>
588where
589    A: Scalar + Lapack,
590    S: DataMut<Elem = A> + RawDataClone,
591{
592    fn rcond_into(self) -> Result<A::Real> {
593        self.factorize_into()?.rcond_into()
594    }
595}