ndarray_linalg/solve.rs
1//! Solve systems of linear equations and invert matrices
2//!
3//! # Examples
4//!
5//! Solve `A * x = b`:
6//!
7//! ```
8//! use ndarray::prelude::*;
9//! use ndarray_linalg::Solve;
10//!
11//! let a: Array2<f64> = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]];
12//! let b: Array1<f64> = array![1., -2., 0.];
13//! let x = a.solve_into(b).unwrap();
14//! assert!(x.abs_diff_eq(&array![1., -2., -2.], 1e-9));
15//! ```
16//!
17//! There are also special functions for solving `A^T * x = b` and
18//! `A^H * x = b`.
19//!
20//! If you are solving multiple systems of linear equations with the same
21//! coefficient matrix `A`, it's faster to compute the LU factorization once at
22//! the beginning than solving directly using `A`:
23//!
24//! ```
25//! use ndarray::prelude::*;
26//! use ndarray_linalg::*;
27//!
//! // Use a fixed PRNG algorithm and seed so that the example is reproducible.
29//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
30//!
31//! let a: Array2<f64> = random_using((3, 3), &mut rng);
32//! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed)
33//! for _ in 0..10 {
34//! let b: Array1<f64> = random_using(3, &mut rng);
35//! let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U
36//! }
37//! ```
38
39use ndarray::*;
40use num_traits::{Float, Zero};
41
42use crate::convert::*;
43use crate::error::*;
44use crate::layout::*;
45use crate::opnorm::OperationNorm;
46use crate::types::*;
47
48pub use lax::{Pivot, Transpose};
49
50/// An interface for solving systems of linear equations.
51///
52/// There are three groups of methods:
53///
54/// * `solve*` (normal) methods solve `A * x = b` for `x`.
55/// * `solve_t*` (transpose) methods solve `A^T * x = b` for `x`.
56/// * `solve_h*` (Hermitian conjugate) methods solve `A^H * x = b` for `x`.
57///
58/// Within each group, there are three methods that handle ownership differently:
59///
60/// * `*` methods take a reference to `b` and return `x` as a new array.
61/// * `*_into` methods take ownership of `b`, store the result in it, and return it.
62/// * `*_inplace` methods take a mutable reference to `b` and store the result in that array.
63///
64/// If you plan to solve many equations with the same `A` matrix but different
65/// `b` vectors, it's faster to factor the `A` matrix once using the
66/// `Factorize` trait, and then solve using the `LUFactorized` struct.
67pub trait Solve<A: Scalar> {
68 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
69 /// is the argument, and `x` is the successful result.
70 ///
71 /// # Panics
72 ///
73 /// Panics if the length of `b` is not the equal to the number of columns
74 /// of `A`.
75 fn solve(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
76 let mut b = replicate(b);
77 self.solve_inplace(&mut b)?;
78 Ok(b)
79 }
80
81 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
82 /// is the argument, and `x` is the successful result.
83 ///
84 /// # Panics
85 ///
86 /// Panics if the length of `b` is not the equal to the number of columns
87 /// of `A`.
88 fn solve_into<S: DataMut<Elem = A>>(
89 &self,
90 mut b: ArrayBase<S, Ix1>,
91 ) -> Result<ArrayBase<S, Ix1>> {
92 self.solve_inplace(&mut b)?;
93 Ok(b)
94 }
95
96 /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
97 /// is the argument, and `x` is the successful result.
98 ///
99 /// # Panics
100 ///
101 /// Panics if the length of `b` is not the equal to the number of columns
102 /// of `A`.
103 fn solve_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
104
105 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
106 /// is the argument, and `x` is the successful result.
107 ///
108 /// # Panics
109 ///
110 /// Panics if the length of `b` is not the equal to the number of rows of
111 /// `A`.
112 fn solve_t(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
113 let mut b = replicate(b);
114 self.solve_t_inplace(&mut b)?;
115 Ok(b)
116 }
117
118 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
119 /// is the argument, and `x` is the successful result.
120 ///
121 /// # Panics
122 ///
123 /// Panics if the length of `b` is not the equal to the number of rows of
124 /// `A`.
125 fn solve_t_into<S: DataMut<Elem = A>>(
126 &self,
127 mut b: ArrayBase<S, Ix1>,
128 ) -> Result<ArrayBase<S, Ix1>> {
129 self.solve_t_inplace(&mut b)?;
130 Ok(b)
131 }
132
133 /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
134 /// is the argument, and `x` is the successful result.
135 ///
136 /// # Panics
137 ///
138 /// Panics if the length of `b` is not the equal to the number of rows of
139 /// `A`.
140 fn solve_t_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
141
142 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
143 /// is the argument, and `x` is the successful result.
144 ///
145 /// # Panics
146 ///
147 /// Panics if the length of `b` is not the equal to the number of rows of
148 /// `A`.
149 fn solve_h(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
150 let mut b = replicate(b);
151 self.solve_h_inplace(&mut b)?;
152 Ok(b)
153 }
154 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
155 /// is the argument, and `x` is the successful result.
156 ///
157 /// # Panics
158 ///
159 /// Panics if the length of `b` is not the equal to the number of rows of
160 /// `A`.
161 fn solve_h_into<S: DataMut<Elem = A>>(
162 &self,
163 mut b: ArrayBase<S, Ix1>,
164 ) -> Result<ArrayBase<S, Ix1>> {
165 self.solve_h_inplace(&mut b)?;
166 Ok(b)
167 }
168 /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
169 /// is the argument, and `x` is the successful result.
170 ///
171 /// # Panics
172 ///
173 /// Panics if the length of `b` is not the equal to the number of rows of
174 /// `A`.
175 fn solve_h_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
176}
177
178/// Represents the LU factorization of a matrix `A` as `A = P*L*U`.
179#[derive(Clone)]
180pub struct LUFactorized<S: Data + RawDataClone> {
181 /// The factors `L` and `U`; the unit diagonal elements of `L` are not
182 /// stored.
183 a: ArrayBase<S, Ix2>,
184 /// The pivot indices that define the permutation matrix `P`.
185 ipiv: Pivot,
186}
187
188impl<A, S> Solve<A> for LUFactorized<S>
189where
190 A: Scalar + Lapack,
191 S: Data<Elem = A> + RawDataClone,
192{
193 fn solve_inplace<'a>(&self, rhs: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>> {
194 assert_eq!(
195 rhs.len(),
196 self.a.len_of(Axis(1)),
197 "The length of `rhs` must be compatible with the shape of the factored matrix.",
198 );
199 A::solve(
200 self.a.square_layout()?,
201 Transpose::No,
202 self.a.as_allocated()?,
203 &self.ipiv,
204 rhs.as_slice_mut().unwrap(),
205 )?;
206 Ok(rhs)
207 }
208 fn solve_t_inplace<'a>(
209 &self,
210 rhs: &'a mut ArrayRef<A, Ix1>,
211 ) -> Result<&'a mut ArrayRef<A, Ix1>> {
212 assert_eq!(
213 rhs.len(),
214 self.a.len_of(Axis(0)),
215 "The length of `rhs` must be compatible with the shape of the factored matrix.",
216 );
217 A::solve(
218 self.a.square_layout()?,
219 Transpose::Transpose,
220 self.a.as_allocated()?,
221 &self.ipiv,
222 rhs.as_slice_mut().unwrap(),
223 )?;
224 Ok(rhs)
225 }
226 fn solve_h_inplace<'a>(
227 &self,
228 rhs: &'a mut ArrayRef<A, Ix1>,
229 ) -> Result<&'a mut ArrayRef<A, Ix1>> {
230 assert_eq!(
231 rhs.len(),
232 self.a.len_of(Axis(0)),
233 "The length of `rhs` must be compatible with the shape of the factored matrix.",
234 );
235 A::solve(
236 self.a.square_layout()?,
237 Transpose::Hermite,
238 self.a.as_allocated()?,
239 &self.ipiv,
240 rhs.as_slice_mut().unwrap(),
241 )?;
242 Ok(rhs)
243 }
244}
245
246impl<A> Solve<A> for ArrayRef<A, Ix2>
247where
248 A: Scalar + Lapack,
249{
250 fn solve_inplace<'a>(&self, rhs: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>> {
251 let f = self.factorize()?;
252 f.solve_inplace(rhs)
253 }
254 fn solve_t_inplace<'a>(
255 &self,
256 rhs: &'a mut ArrayRef<A, Ix1>,
257 ) -> Result<&'a mut ArrayRef<A, Ix1>> {
258 let f = self.factorize()?;
259 f.solve_t_inplace(rhs)
260 }
261 fn solve_h_inplace<'a>(
262 &self,
263 rhs: &'a mut ArrayRef<A, Ix1>,
264 ) -> Result<&'a mut ArrayRef<A, Ix1>> {
265 let f = self.factorize()?;
266 f.solve_h_inplace(rhs)
267 }
268}
269
270/// An interface for computing LU factorizations of matrix refs.
271pub trait Factorize<S: Data + RawDataClone> {
272 /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
273 /// matrix.
274 fn factorize(&self) -> Result<LUFactorized<S>>;
275}
276
277/// An interface for computing LU factorizations of matrices.
278pub trait FactorizeInto<S: Data + RawDataClone> {
279 /// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
280 /// matrix.
281 fn factorize_into(self) -> Result<LUFactorized<S>>;
282}
283
284impl<A, S> FactorizeInto<S> for ArrayBase<S, Ix2>
285where
286 A: Scalar + Lapack,
287 S: DataMut<Elem = A> + RawDataClone,
288{
289 fn factorize_into(mut self) -> Result<LUFactorized<S>> {
290 let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
291 Ok(LUFactorized { a: self, ipiv })
292 }
293}
294
295impl<A> Factorize<OwnedRepr<A>> for ArrayRef<A, Ix2>
296where
297 A: Scalar + Lapack,
298{
299 fn factorize(&self) -> Result<LUFactorized<OwnedRepr<A>>> {
300 let mut a: Array2<A> = replicate(self);
301 let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
302 Ok(LUFactorized { a, ipiv })
303 }
304}
305
306/// An interface for inverting matrix refs.
307pub trait Inverse {
308 type Output;
309 /// Computes the inverse of the matrix.
310 fn inv(&self) -> Result<Self::Output>;
311}
312
313/// An interface for inverting matrices.
314pub trait InverseInto {
315 type Output;
316 /// Computes the inverse of the matrix.
317 fn inv_into(self) -> Result<Self::Output>;
318}
319
320impl<A, S> InverseInto for LUFactorized<S>
321where
322 A: Scalar + Lapack,
323 S: DataMut<Elem = A> + RawDataClone,
324{
325 type Output = ArrayBase<S, Ix2>;
326
327 fn inv_into(mut self) -> Result<ArrayBase<S, Ix2>> {
328 A::inv(
329 self.a.square_layout()?,
330 self.a.as_allocated_mut()?,
331 &self.ipiv,
332 )?;
333 Ok(self.a)
334 }
335}
336
337impl<A, S> Inverse for LUFactorized<S>
338where
339 A: Scalar + Lapack,
340 S: Data<Elem = A> + RawDataClone,
341{
342 type Output = Array2<A>;
343
344 fn inv(&self) -> Result<Array2<A>> {
345 // Preserve the existing layout. This is required to obtain the correct
346 // result, because the result of `A::inv` is layout-dependent.
347 let a = if self.a.is_standard_layout() {
348 replicate(&self.a)
349 } else {
350 replicate(&self.a.t()).reversed_axes()
351 };
352 let f = LUFactorized {
353 a,
354 ipiv: self.ipiv.clone(),
355 };
356 f.inv_into()
357 }
358}
359
360impl<A, S> InverseInto for ArrayBase<S, Ix2>
361where
362 A: Scalar + Lapack,
363 S: DataMut<Elem = A> + RawDataClone,
364{
365 type Output = Self;
366
367 fn inv_into(self) -> Result<Self::Output> {
368 let f = self.factorize_into()?;
369 f.inv_into()
370 }
371}
372
373impl<A> Inverse for ArrayRef<A, Ix2>
374where
375 A: Scalar + Lapack,
376{
377 type Output = Array2<A>;
378
379 fn inv(&self) -> Result<Self::Output> {
380 let f = self.factorize()?;
381 f.inv_into()
382 }
383}
384
385/// An interface for calculating determinants of matrix refs.
386pub trait Determinant<A: Scalar> {
387 /// Computes the determinant of the matrix.
388 fn det(&self) -> Result<A> {
389 let (sign, ln_det) = self.sln_det()?;
390 Ok(sign * A::from_real(Float::exp(ln_det)))
391 }
392
393 /// Computes the `(sign, natural_log)` of the determinant of the matrix.
394 ///
395 /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
396 /// `sign` is `0` or a complex number with absolute value 1. The
397 /// `natural_log` is the natural logarithm of the absolute value of the
398 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
399 /// is negative infinity.
400 ///
401 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
402 /// or just call `.det()` instead.
403 ///
404 /// This method is more robust than `.det()` to very small or very large
405 /// determinants since it returns the natural logarithm of the determinant
406 /// rather than the determinant itself.
407 fn sln_det(&self) -> Result<(A, A::Real)>;
408}
409
410/// An interface for calculating determinants of matrices.
411pub trait DeterminantInto<A: Scalar>: Sized {
412 /// Computes the determinant of the matrix.
413 fn det_into(self) -> Result<A> {
414 let (sign, ln_det) = self.sln_det_into()?;
415 Ok(sign * A::from_real(Float::exp(ln_det)))
416 }
417
418 /// Computes the `(sign, natural_log)` of the determinant of the matrix.
419 ///
420 /// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
421 /// `sign` is `0` or a complex number with absolute value 1. The
422 /// `natural_log` is the natural logarithm of the absolute value of the
423 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
424 /// is negative infinity.
425 ///
426 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
427 /// or just call `.det_into()` instead.
428 ///
429 /// This method is more robust than `.det()` to very small or very large
430 /// determinants since it returns the natural logarithm of the determinant
431 /// rather than the determinant itself.
432 fn sln_det_into(self) -> Result<(A, A::Real)>;
433}
434
435fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
436where
437 A: Scalar + Lapack,
438 P: Iterator<Item = i32>,
439 U: Iterator<Item = &'a A>,
440{
441 let pivot_sign = if ipiv_iter
442 .enumerate()
443 .filter(|&(i, pivot)| pivot != i as i32 + 1)
444 .count()
445 % 2
446 == 0
447 {
448 A::one()
449 } else {
450 -A::one()
451 };
452 let (upper_sign, ln_det) = u_diag_iter.fold(
453 (A::one(), A::Real::zero()),
454 |(upper_sign, ln_det), &elem| {
455 let abs_elem: A::Real = elem.abs();
456 (
457 upper_sign * elem / A::from_real(abs_elem),
458 ln_det + Float::ln(abs_elem),
459 )
460 },
461 );
462 (pivot_sign * upper_sign, ln_det)
463}
464
465impl<A, S> Determinant<A> for LUFactorized<S>
466where
467 A: Scalar + Lapack,
468 S: Data<Elem = A> + RawDataClone,
469{
470 fn sln_det(&self) -> Result<(A, A::Real)> {
471 self.a.ensure_square()?;
472 Ok(lu_sln_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
473 }
474}
475
476impl<A, S> DeterminantInto<A> for LUFactorized<S>
477where
478 A: Scalar + Lapack,
479 S: Data<Elem = A> + RawDataClone,
480{
481 fn sln_det_into(self) -> Result<(A, A::Real)> {
482 self.a.ensure_square()?;
483 Ok(lu_sln_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
484 }
485}
486
487impl<A> Determinant<A> for ArrayRef<A, Ix2>
488where
489 A: Scalar + Lapack,
490{
491 fn sln_det(&self) -> Result<(A, A::Real)> {
492 self.ensure_square()?;
493 match self.factorize() {
494 Ok(fac) => fac.sln_det(),
495 Err(LinalgError::Lapack(e))
496 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
497 {
498 // The determinant is zero.
499 Ok((A::zero(), A::Real::neg_infinity()))
500 }
501 Err(err) => Err(err),
502 }
503 }
504}
505
506impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
507where
508 A: Scalar + Lapack,
509 S: DataMut<Elem = A> + RawDataClone,
510{
511 fn sln_det_into(self) -> Result<(A, A::Real)> {
512 self.ensure_square()?;
513 match self.factorize_into() {
514 Ok(fac) => fac.sln_det_into(),
515 Err(LinalgError::Lapack(e))
516 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
517 {
518 // The determinant is zero.
519 Ok((A::zero(), A::Real::neg_infinity()))
520 }
521 Err(err) => Err(err),
522 }
523 }
524}
525
526/// An interface for *estimating* the reciprocal condition number of matrix refs.
527pub trait ReciprocalConditionNum<A: Scalar> {
528 /// *Estimates* the reciprocal of the condition number of the matrix in
529 /// 1-norm.
530 ///
531 /// This method uses the LAPACK `*gecon` routines, which *estimate*
532 /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
533 /// (self.opnorm_one() * self.inv().opnorm_one())`.
534 ///
535 /// * If `rcond` is near `0.`, the matrix is badly conditioned.
536 /// * If `rcond` is near `1.`, the matrix is well conditioned.
537 fn rcond(&self) -> Result<A::Real>;
538}
539
540/// An interface for *estimating* the reciprocal condition number of matrices.
541pub trait ReciprocalConditionNumInto<A: Scalar> {
542 /// *Estimates* the reciprocal of the condition number of the matrix in
543 /// 1-norm.
544 ///
545 /// This method uses the LAPACK `*gecon` routines, which *estimate*
546 /// `self.inv().opnorm_one()` and then compute `rcond = 1. /
547 /// (self.opnorm_one() * self.inv().opnorm_one())`.
548 ///
549 /// * If `rcond` is near `0.`, the matrix is badly conditioned.
550 /// * If `rcond` is near `1.`, the matrix is well conditioned.
551 fn rcond_into(self) -> Result<A::Real>;
552}
553
554impl<A, S> ReciprocalConditionNum<A> for LUFactorized<S>
555where
556 A: Scalar + Lapack,
557 S: Data<Elem = A> + RawDataClone,
558{
559 fn rcond(&self) -> Result<A::Real> {
560 Ok(A::rcond(
561 self.a.layout()?,
562 self.a.as_allocated()?,
563 self.a.opnorm_one()?,
564 )?)
565 }
566}
567
568impl<A, S> ReciprocalConditionNumInto<A> for LUFactorized<S>
569where
570 A: Scalar + Lapack,
571 S: Data<Elem = A> + RawDataClone,
572{
573 fn rcond_into(self) -> Result<A::Real> {
574 self.rcond()
575 }
576}
577
578impl<A> ReciprocalConditionNum<A> for ArrayRef<A, Ix2>
579where
580 A: Scalar + Lapack,
581{
582 fn rcond(&self) -> Result<A::Real> {
583 self.factorize()?.rcond_into()
584 }
585}
586
587impl<A, S> ReciprocalConditionNumInto<A> for ArrayBase<S, Ix2>
588where
589 A: Scalar + Lapack,
590 S: DataMut<Elem = A> + RawDataClone,
591{
592 fn rcond_into(self) -> Result<A::Real> {
593 self.factorize_into()?.rcond_into()
594 }
595}