ndarray_linalg/
solve.rs

1//! Solve systems of linear equations and invert matrices
2//!
3//! # Examples
4//!
5//! Solve `A * x = b`:
6//!
7//! ```
8//! use ndarray::prelude::*;
9//! use ndarray_linalg::Solve;
10//!
11//! let a: Array2<f64> = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]];
12//! let b: Array1<f64> = array![1., -2., 0.];
13//! let x = a.solve_into(b).unwrap();
14//! assert!(x.abs_diff_eq(&array![1., -2., -2.], 1e-9));
15//! ```
16//!
17//! There are also special functions for solving `A^T * x = b` and
18//! `A^H * x = b`.
19//!
20//! If you are solving multiple systems of linear equations with the same
21//! coefficient matrix `A`, it's faster to compute the LU factorization once at
22//! the beginning than solving directly using `A`:
23//!
24//! ```
25//! use ndarray::prelude::*;
26//! use ndarray_linalg::*;
27//!
//! // Use a fixed PRNG algorithm and seed so that the example is reproducible.
29//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
30//!
31//! let a: Array2<f64> = random_using((3, 3), &mut rng);
32//! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed)
33//! for _ in 0..10 {
34//!     let b: Array1<f64> = random_using(3, &mut  rng);
35//!     let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U
36//! }
37//! ```
38
39use ndarray::*;
40use num_traits::{Float, Zero};
41
42use crate::convert::*;
43use crate::error::*;
44use crate::layout::*;
45use crate::opnorm::OperationNorm;
46use crate::types::*;
47
48pub use lax::{Pivot, Transpose};
49
50/// An interface for solving systems of linear equations.
51///
52/// There are three groups of methods:
53///
54/// * `solve*` (normal) methods solve `A * x = b` for `x`.
55/// * `solve_t*` (transpose) methods solve `A^T * x = b` for `x`.
56/// * `solve_h*` (Hermitian conjugate) methods solve `A^H * x = b` for `x`.
57///
58/// Within each group, there are three methods that handle ownership differently:
59///
60/// * `*` methods take a reference to `b` and return `x` as a new array.
61/// * `*_into` methods take ownership of `b`, store the result in it, and return it.
62/// * `*_inplace` methods take a mutable reference to `b` and store the result in that array.
63///
64/// If you plan to solve many equations with the same `A` matrix but different
65/// `b` vectors, it's faster to factor the `A` matrix once using the
66/// `Factorize` trait, and then solve using the `LUFactorized` struct.
67pub trait Solve<A: Scalar> {
68    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
69    /// is the argument, and `x` is the successful result.
70    ///
71    /// # Panics
72    ///
73    /// Panics if the length of `b` is not the equal to the number of columns
74    /// of `A`.
75    fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
76        let mut b = replicate(b);
77        self.solve_inplace(&mut b)?;
78        Ok(b)
79    }
80
81    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
82    /// is the argument, and `x` is the successful result.
83    ///
84    /// # Panics
85    ///
86    /// Panics if the length of `b` is not the equal to the number of columns
87    /// of `A`.
88    fn solve_into<S: DataMut<Elem = A>>(
89        &self,
90        mut b: ArrayBase<S, Ix1>,
91    ) -> Result<ArrayBase<S, Ix1>> {
92        self.solve_inplace(&mut b)?;
93        Ok(b)
94    }
95
96    /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
97    /// is the argument, and `x` is the successful result.
98    ///
99    /// # Panics
100    ///
101    /// Panics if the length of `b` is not the equal to the number of columns
102    /// of `A`.
103    fn solve_inplace<'a, S: DataMut<Elem = A>>(
104        &self,
105        b: &'a mut ArrayBase<S, Ix1>,
106    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
107
108    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
109    /// is the argument, and `x` is the successful result.
110    ///
111    /// # Panics
112    ///
113    /// Panics if the length of `b` is not the equal to the number of rows of
114    /// `A`.
115    fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
116        let mut b = replicate(b);
117        self.solve_t_inplace(&mut b)?;
118        Ok(b)
119    }
120
121    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
122    /// is the argument, and `x` is the successful result.
123    ///
124    /// # Panics
125    ///
126    /// Panics if the length of `b` is not the equal to the number of rows of
127    /// `A`.
128    fn solve_t_into<S: DataMut<Elem = A>>(
129        &self,
130        mut b: ArrayBase<S, Ix1>,
131    ) -> Result<ArrayBase<S, Ix1>> {
132        self.solve_t_inplace(&mut b)?;
133        Ok(b)
134    }
135
136    /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
137    /// is the argument, and `x` is the successful result.
138    ///
139    /// # Panics
140    ///
141    /// Panics if the length of `b` is not the equal to the number of rows of
142    /// `A`.
143    fn solve_t_inplace<'a, S: DataMut<Elem = A>>(
144        &self,
145        b: &'a mut ArrayBase<S, Ix1>,
146    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
147
148    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
149    /// is the argument, and `x` is the successful result.
150    ///
151    /// # Panics
152    ///
153    /// Panics if the length of `b` is not the equal to the number of rows of
154    /// `A`.
155    fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
156        let mut b = replicate(b);
157        self.solve_h_inplace(&mut b)?;
158        Ok(b)
159    }
160    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
161    /// is the argument, and `x` is the successful result.
162    ///
163    /// # Panics
164    ///
165    /// Panics if the length of `b` is not the equal to the number of rows of
166    /// `A`.
167    fn solve_h_into<S: DataMut<Elem = A>>(
168        &self,
169        mut b: ArrayBase<S, Ix1>,
170    ) -> Result<ArrayBase<S, Ix1>> {
171        self.solve_h_inplace(&mut b)?;
172        Ok(b)
173    }
174    /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
175    /// is the argument, and `x` is the successful result.
176    ///
177    /// # Panics
178    ///
179    /// Panics if the length of `b` is not the equal to the number of rows of
180    /// `A`.
181    fn solve_h_inplace<'a, S: DataMut<Elem = A>>(
182        &self,
183        b: &'a mut ArrayBase<S, Ix1>,
184    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
185}
186
187/// Represents the LU factorization of a matrix `A` as `A = P*L*U`.
188#[derive(Clone)]
189pub struct LUFactorized<S: Data + RawDataClone> {
190    /// The factors `L` and `U`; the unit diagonal elements of `L` are not
191    /// stored.
192    a: ArrayBase<S, Ix2>,
193    /// The pivot indices that define the permutation matrix `P`.
194    ipiv: Pivot,
195}
196
197impl<A, S> Solve<A> for LUFactorized<S>
198where
199    A: Scalar + Lapack,
200    S: Data<Elem = A> + RawDataClone,
201{
202    fn solve_inplace<'a, Sb>(
203        &self,
204        rhs: &'a mut ArrayBase<Sb, Ix1>,
205    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
206    where
207        Sb: DataMut<Elem = A>,
208    {
209        assert_eq!(
210            rhs.len(),
211            self.a.len_of(Axis(1)),
212            "The length of `rhs` must be compatible with the shape of the factored matrix.",
213        );
214        A::solve(
215            self.a.square_layout()?,
216            Transpose::No,
217            self.a.as_allocated()?,
218            &self.ipiv,
219            rhs.as_slice_mut().unwrap(),
220        )?;
221        Ok(rhs)
222    }
223    fn solve_t_inplace<'a, Sb>(
224        &self,
225        rhs: &'a mut ArrayBase<Sb, Ix1>,
226    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
227    where
228        Sb: DataMut<Elem = A>,
229    {
230        assert_eq!(
231            rhs.len(),
232            self.a.len_of(Axis(0)),
233            "The length of `rhs` must be compatible with the shape of the factored matrix.",
234        );
235        A::solve(
236            self.a.square_layout()?,
237            Transpose::Transpose,
238            self.a.as_allocated()?,
239            &self.ipiv,
240            rhs.as_slice_mut().unwrap(),
241        )?;
242        Ok(rhs)
243    }
244    fn solve_h_inplace<'a, Sb>(
245        &self,
246        rhs: &'a mut ArrayBase<Sb, Ix1>,
247    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
248    where
249        Sb: DataMut<Elem = A>,
250    {
251        assert_eq!(
252            rhs.len(),
253            self.a.len_of(Axis(0)),
254            "The length of `rhs` must be compatible with the shape of the factored matrix.",
255        );
256        A::solve(
257            self.a.square_layout()?,
258            Transpose::Hermite,
259            self.a.as_allocated()?,
260            &self.ipiv,
261            rhs.as_slice_mut().unwrap(),
262        )?;
263        Ok(rhs)
264    }
265}
266
267impl<A, S> Solve<A> for ArrayBase<S, Ix2>
268where
269    A: Scalar + Lapack,
270    S: Data<Elem = A>,
271{
272    fn solve_inplace<'a, Sb>(
273        &self,
274        rhs: &'a mut ArrayBase<Sb, Ix1>,
275    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
276    where
277        Sb: DataMut<Elem = A>,
278    {
279        let f = self.factorize()?;
280        f.solve_inplace(rhs)
281    }
282    fn solve_t_inplace<'a, Sb>(
283        &self,
284        rhs: &'a mut ArrayBase<Sb, Ix1>,
285    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
286    where
287        Sb: DataMut<Elem = A>,
288    {
289        let f = self.factorize()?;
290        f.solve_t_inplace(rhs)
291    }
292    fn solve_h_inplace<'a, Sb>(
293        &self,
294        rhs: &'a mut ArrayBase<Sb, Ix1>,
295    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
296    where
297        Sb: DataMut<Elem = A>,
298    {
299        let f = self.factorize()?;
300        f.solve_h_inplace(rhs)
301    }
302}
303
304/// An interface for computing LU factorizations of matrix refs.
305pub trait Factorize<S: Data + RawDataClone> {
306    /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
307    /// matrix.
308    fn factorize(&self) -> Result<LUFactorized<S>>;
309}
310
311/// An interface for computing LU factorizations of matrices.
312pub trait FactorizeInto<S: Data + RawDataClone> {
313    /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
314    /// matrix.
315    fn factorize_into(self) -> Result<LUFactorized<S>>;
316}
317
318impl<A, S> FactorizeInto<S> for ArrayBase<S, Ix2>
319where
320    A: Scalar + Lapack,
321    S: DataMut<Elem = A> + RawDataClone,
322{
323    fn factorize_into(mut self) -> Result<LUFactorized<S>> {
324        let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
325        Ok(LUFactorized { a: self, ipiv })
326    }
327}
328
329impl<A, Si> Factorize<OwnedRepr<A>> for ArrayBase<Si, Ix2>
330where
331    A: Scalar + Lapack,
332    Si: Data<Elem = A>,
333{
334    fn factorize(&self) -> Result<LUFactorized<OwnedRepr<A>>> {
335        let mut a: Array2<A> = replicate(self);
336        let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
337        Ok(LUFactorized { a, ipiv })
338    }
339}
340
/// An interface for inverting matrix refs.
///
/// The receiver is only borrowed, so the inverse is returned as a separate
/// value of type `Output`.
pub trait Inverse {
    /// The type of the returned inverse.
    type Output;
    /// Computes the inverse of the matrix.
    fn inv(&self) -> Result<Self::Output>;
}
347
/// An interface for inverting matrices.
///
/// The receiver is consumed by the call.
pub trait InverseInto {
    /// The type of the returned inverse.
    type Output;
    /// Computes the inverse of the matrix, consuming `self`.
    fn inv_into(self) -> Result<Self::Output>;
}
354
355impl<A, S> InverseInto for LUFactorized<S>
356where
357    A: Scalar + Lapack,
358    S: DataMut<Elem = A> + RawDataClone,
359{
360    type Output = ArrayBase<S, Ix2>;
361
362    fn inv_into(mut self) -> Result<ArrayBase<S, Ix2>> {
363        A::inv(
364            self.a.square_layout()?,
365            self.a.as_allocated_mut()?,
366            &self.ipiv,
367        )?;
368        Ok(self.a)
369    }
370}
371
372impl<A, S> Inverse for LUFactorized<S>
373where
374    A: Scalar + Lapack,
375    S: Data<Elem = A> + RawDataClone,
376{
377    type Output = Array2<A>;
378
379    fn inv(&self) -> Result<Array2<A>> {
380        // Preserve the existing layout. This is required to obtain the correct
381        // result, because the result of `A::inv` is layout-dependent.
382        let a = if self.a.is_standard_layout() {
383            replicate(&self.a)
384        } else {
385            replicate(&self.a.t()).reversed_axes()
386        };
387        let f = LUFactorized {
388            a,
389            ipiv: self.ipiv.clone(),
390        };
391        f.inv_into()
392    }
393}
394
395impl<A, S> InverseInto for ArrayBase<S, Ix2>
396where
397    A: Scalar + Lapack,
398    S: DataMut<Elem = A> + RawDataClone,
399{
400    type Output = Self;
401
402    fn inv_into(self) -> Result<Self::Output> {
403        let f = self.factorize_into()?;
404        f.inv_into()
405    }
406}
407
408impl<A, Si> Inverse for ArrayBase<Si, Ix2>
409where
410    A: Scalar + Lapack,
411    Si: Data<Elem = A>,
412{
413    type Output = Array2<A>;
414
415    fn inv(&self) -> Result<Self::Output> {
416        let f = self.factorize()?;
417        f.inv_into()
418    }
419}
420
421/// An interface for calculating determinants of matrix refs.
422pub trait Determinant<A: Scalar> {
423    /// Computes the determinant of the matrix.
424    fn det(&self) -> Result<A> {
425        let (sign, ln_det) = self.sln_det()?;
426        Ok(sign * A::from_real(Float::exp(ln_det)))
427    }
428
429    /// Computes the `(sign, natural_log)` of the determinant of the matrix.
430    ///
431    /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
432    /// `sign` is `0` or a complex number with absolute value 1. The
433    /// `natural_log` is the natural logarithm of the absolute value of the
434    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
435    /// is negative infinity.
436    ///
437    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
438    /// or just call `.det()` instead.
439    ///
440    /// This method is more robust than `.det()` to very small or very large
441    /// determinants since it returns the natural logarithm of the determinant
442    /// rather than the determinant itself.
443    fn sln_det(&self) -> Result<(A, A::Real)>;
444}
445
446/// An interface for calculating determinants of matrices.
447pub trait DeterminantInto<A: Scalar>: Sized {
448    /// Computes the determinant of the matrix.
449    fn det_into(self) -> Result<A> {
450        let (sign, ln_det) = self.sln_det_into()?;
451        Ok(sign * A::from_real(Float::exp(ln_det)))
452    }
453
454    /// Computes the `(sign, natural_log)` of the determinant of the matrix.
455    ///
456    /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
457    /// `sign` is `0` or a complex number with absolute value 1. The
458    /// `natural_log` is the natural logarithm of the absolute value of the
459    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
460    /// is negative infinity.
461    ///
462    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
463    /// or just call `.det_into()` instead.
464    ///
465    /// This method is more robust than `.det()` to very small or very large
466    /// determinants since it returns the natural logarithm of the determinant
467    /// rather than the determinant itself.
468    fn sln_det_into(self) -> Result<(A, A::Real)>;
469}
470
471fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
472where
473    A: Scalar + Lapack,
474    P: Iterator<Item = i32>,
475    U: Iterator<Item = &'a A>,
476{
477    let pivot_sign = if ipiv_iter
478        .enumerate()
479        .filter(|&(i, pivot)| pivot != i as i32 + 1)
480        .count()
481        % 2
482        == 0
483    {
484        A::one()
485    } else {
486        -A::one()
487    };
488    let (upper_sign, ln_det) = u_diag_iter.fold(
489        (A::one(), A::Real::zero()),
490        |(upper_sign, ln_det), &elem| {
491            let abs_elem: A::Real = elem.abs();
492            (
493                upper_sign * elem / A::from_real(abs_elem),
494                ln_det + Float::ln(abs_elem),
495            )
496        },
497    );
498    (pivot_sign * upper_sign, ln_det)
499}
500
501impl<A, S> Determinant<A> for LUFactorized<S>
502where
503    A: Scalar + Lapack,
504    S: Data<Elem = A> + RawDataClone,
505{
506    fn sln_det(&self) -> Result<(A, A::Real)> {
507        self.a.ensure_square()?;
508        Ok(lu_sln_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
509    }
510}
511
512impl<A, S> DeterminantInto<A> for LUFactorized<S>
513where
514    A: Scalar + Lapack,
515    S: Data<Elem = A> + RawDataClone,
516{
517    fn sln_det_into(self) -> Result<(A, A::Real)> {
518        self.a.ensure_square()?;
519        Ok(lu_sln_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
520    }
521}
522
523impl<A, S> Determinant<A> for ArrayBase<S, Ix2>
524where
525    A: Scalar + Lapack,
526    S: Data<Elem = A>,
527{
528    fn sln_det(&self) -> Result<(A, A::Real)> {
529        self.ensure_square()?;
530        match self.factorize() {
531            Ok(fac) => fac.sln_det(),
532            Err(LinalgError::Lapack(e))
533                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
534            {
535                // The determinant is zero.
536                Ok((A::zero(), A::Real::neg_infinity()))
537            }
538            Err(err) => Err(err),
539        }
540    }
541}
542
543impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
544where
545    A: Scalar + Lapack,
546    S: DataMut<Elem = A> + RawDataClone,
547{
548    fn sln_det_into(self) -> Result<(A, A::Real)> {
549        self.ensure_square()?;
550        match self.factorize_into() {
551            Ok(fac) => fac.sln_det_into(),
552            Err(LinalgError::Lapack(e))
553                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
554            {
555                // The determinant is zero.
556                Ok((A::zero(), A::Real::neg_infinity()))
557            }
558            Err(err) => Err(err),
559        }
560    }
561}
562
563/// An interface for *estimating* the reciprocal condition number of matrix refs.
564pub trait ReciprocalConditionNum<A: Scalar> {
565    /// *Estimates* the reciprocal of the condition number of the matrix in
566    /// 1-norm.
567    ///
568    /// This method uses the LAPACK `*gecon` routines, which *estimate*
569    /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
570    /// (self.opnorm_one() * self.inv().opnorm_one())`.
571    ///
572    /// * If `rcond` is near `0.`, the matrix is badly conditioned.
573    /// * If `rcond` is near `1.`, the matrix is well conditioned.
574    fn rcond(&self) -> Result<A::Real>;
575}
576
577/// An interface for *estimating* the reciprocal condition number of matrices.
578pub trait ReciprocalConditionNumInto<A: Scalar> {
579    /// *Estimates* the reciprocal of the condition number of the matrix in
580    /// 1-norm.
581    ///
582    /// This method uses the LAPACK `*gecon` routines, which *estimate*
583    /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
584    /// (self.opnorm_one() * self.inv().opnorm_one())`.
585    ///
586    /// * If `rcond` is near `0.`, the matrix is badly conditioned.
587    /// * If `rcond` is near `1.`, the matrix is well conditioned.
588    fn rcond_into(self) -> Result<A::Real>;
589}
590
591impl<A, S> ReciprocalConditionNum<A> for LUFactorized<S>
592where
593    A: Scalar + Lapack,
594    S: Data<Elem = A> + RawDataClone,
595{
596    fn rcond(&self) -> Result<A::Real> {
597        Ok(A::rcond(
598            self.a.layout()?,
599            self.a.as_allocated()?,
600            self.a.opnorm_one()?,
601        )?)
602    }
603}
604
605impl<A, S> ReciprocalConditionNumInto<A> for LUFactorized<S>
606where
607    A: Scalar + Lapack,
608    S: Data<Elem = A> + RawDataClone,
609{
610    fn rcond_into(self) -> Result<A::Real> {
611        self.rcond()
612    }
613}
614
615impl<A, S> ReciprocalConditionNum<A> for ArrayBase<S, Ix2>
616where
617    A: Scalar + Lapack,
618    S: Data<Elem = A>,
619{
620    fn rcond(&self) -> Result<A::Real> {
621        self.factorize()?.rcond_into()
622    }
623}
624
625impl<A, S> ReciprocalConditionNumInto<A> for ArrayBase<S, Ix2>
626where
627    A: Scalar + Lapack,
628    S: DataMut<Elem = A> + RawDataClone,
629{
630    fn rcond_into(self) -> Result<A::Real> {
631        self.factorize_into()?.rcond_into()
632    }
633}