ndarray_linalg/
solveh.rs

1//! Solve Hermitian (or real symmetric) linear problems and invert Hermitian
2//! (or real symmetric) matrices
3//!
4//! **Note that only the upper triangular portion of the matrix is used.**
5//!
6//! # Examples
7//!
8//! Solve `A * x = b`, where `A` is a Hermitian (or real symmetric) matrix:
9//!
10//! ```
11//! use ndarray::prelude::*;
12//! use ndarray_linalg::SolveH;
13//!
14//! let a: Array2<f64> = array![
15//!     [3., 2., -1.],
16//!     [2., -2., 4.],
17//!     [-1., 4., 5.]
18//! ];
19//! let b: Array1<f64> = array![11., -12., 1.];
20//! let x = a.solveh_into(b).unwrap();
21//! assert!(x.abs_diff_eq(&array![1., 3., -2.], 1e-9));
22//! ```
23//!
24//! If you are solving multiple systems of linear equations with the same
25//! Hermitian or real symmetric coefficient matrix `A`, it's faster to compute
26//! the factorization once at the beginning than solving directly using `A`:
27//!
28//! ```
29//! use ndarray::prelude::*;
30//! use ndarray_linalg::*;
31//!
32//! /// Use fixed algorithm and seed of PRNG for reproducible test
33//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
34//!
35//! let a: Array2<f64> = random_using((3, 3), &mut rng);
36//! let f = a.factorizeh_into().unwrap(); // Factorize A (A is consumed)
37//! for _ in 0..10 {
38//!     let b: Array1<f64> = random_using(3, &mut rng);
39//!     let x = f.solveh_into(b).unwrap(); // Solve A * x = b using the factorization
40//! }
41//! ```
42
43use ndarray::*;
44use num_traits::{Float, One, Zero};
45
46use crate::convert::*;
47use crate::error::*;
48use crate::layout::*;
49use crate::types::*;
50
51pub use lax::{Pivot, UPLO};
52
53/// An interface for solving systems of Hermitian (or real symmetric) linear equations.
54///
55/// If you plan to solve many equations with the same Hermitian (or real
56/// symmetric) coefficient matrix `A` but different `b` vectors, it's faster to
57/// factor the `A` matrix once using the `FactorizeH` trait, and then solve
58/// using the `BKFactorized` struct.
59pub trait SolveH<A: Scalar> {
60    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
61    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
62    /// `x` is the successful result.
63    ///
64    /// # Panics
65    ///
66    /// Panics if the length of `b` is not the equal to the number of columns
67    /// of `A`.
68    fn solveh(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
69        let mut b = replicate(b);
70        self.solveh_inplace(&mut b)?;
71        Ok(b)
72    }
73
74    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
75    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
76    /// `x` is the successful result.
77    ///
78    /// # Panics
79    ///
80    /// Panics if the length of `b` is not the equal to the number of columns
81    /// of `A`.
82    fn solveh_into<S: DataMut<Elem = A>>(
83        &self,
84        mut b: ArrayBase<S, Ix1>,
85    ) -> Result<ArrayBase<S, Ix1>> {
86        self.solveh_inplace(&mut b)?;
87        Ok(b)
88    }
89
90    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
91    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
92    /// `x` is the successful result. The value of `x` is also assigned to the
93    /// argument.
94    ///
95    /// # Panics
96    ///
97    /// Panics if the length of `b` is not the equal to the number of columns
98    /// of `A`.
99    fn solveh_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
100}
101
102/// Represents the Bunch–Kaufman factorization of a Hermitian (or real
103/// symmetric) matrix as `A = P * U * D * U^H * P^T`.
104pub struct BKFactorized<S: Data> {
105    pub a: ArrayBase<S, Ix2>,
106    pub ipiv: Pivot,
107}
108
109impl<A, S> SolveH<A> for BKFactorized<S>
110where
111    A: Scalar + Lapack,
112    S: Data<Elem = A>,
113{
114    fn solveh_inplace<'a>(
115        &self,
116        rhs: &'a mut ArrayRef<A, Ix1>,
117    ) -> Result<&'a mut ArrayRef<A, Ix1>> {
118        assert_eq!(
119            rhs.len(),
120            self.a.len_of(Axis(1)),
121            "The length of `rhs` must be compatible with the shape of the factored matrix.",
122        );
123        A::solveh(
124            self.a.square_layout()?,
125            UPLO::Upper,
126            self.a.as_allocated()?,
127            &self.ipiv,
128            rhs.as_slice_mut().unwrap(),
129        )?;
130        Ok(rhs)
131    }
132}
133
134impl<A> SolveH<A> for ArrayRef<A, Ix2>
135where
136    A: Scalar + Lapack,
137{
138    fn solveh_inplace<'a>(
139        &self,
140        rhs: &'a mut ArrayRef<A, Ix1>,
141    ) -> Result<&'a mut ArrayRef<A, Ix1>> {
142        let f = self.factorizeh()?;
143        f.solveh_inplace(rhs)
144    }
145}
146
147/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
148/// real symmetric) matrix refs.
149pub trait FactorizeH<S: Data> {
150    /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
151    /// symmetric) matrix.
152    fn factorizeh(&self) -> Result<BKFactorized<S>>;
153}
154
155/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
156/// real symmetric) matrices.
157pub trait FactorizeHInto<S: Data> {
158    /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
159    /// symmetric) matrix.
160    fn factorizeh_into(self) -> Result<BKFactorized<S>>;
161}
162
163impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
164where
165    A: Scalar + Lapack,
166    S: DataMut<Elem = A>,
167{
168    fn factorizeh_into(mut self) -> Result<BKFactorized<S>> {
169        let ipiv = A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)?;
170        Ok(BKFactorized { a: self, ipiv })
171    }
172}
173
174impl<A> FactorizeH<OwnedRepr<A>> for ArrayRef<A, Ix2>
175where
176    A: Scalar + Lapack,
177{
178    fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
179        let mut a: Array2<A> = replicate(self);
180        let ipiv = A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)?;
181        Ok(BKFactorized { a, ipiv })
182    }
183}
184
185/// An interface for inverting Hermitian (or real symmetric) matrix refs.
186pub trait InverseH {
187    type Output;
188    /// Computes the inverse of the Hermitian (or real symmetric) matrix.
189    fn invh(&self) -> Result<Self::Output>;
190}
191
192/// An interface for inverting Hermitian (or real symmetric) matrices.
193pub trait InverseHInto {
194    type Output;
195    /// Computes the inverse of the Hermitian (or real symmetric) matrix.
196    fn invh_into(self) -> Result<Self::Output>;
197}
198
199impl<A, S> InverseHInto for BKFactorized<S>
200where
201    A: Scalar + Lapack,
202    S: DataMut<Elem = A>,
203{
204    type Output = ArrayBase<S, Ix2>;
205
206    fn invh_into(mut self) -> Result<ArrayBase<S, Ix2>> {
207        A::invh(
208            self.a.square_layout()?,
209            UPLO::Upper,
210            self.a.as_allocated_mut()?,
211            &self.ipiv,
212        )?;
213        triangular_fill_hermitian(&mut self.a, UPLO::Upper);
214        Ok(self.a)
215    }
216}
217
218impl<A, S> InverseH for BKFactorized<S>
219where
220    A: Scalar + Lapack,
221    S: Data<Elem = A>,
222{
223    type Output = Array2<A>;
224
225    fn invh(&self) -> Result<Self::Output> {
226        let f = BKFactorized {
227            a: replicate(&self.a),
228            ipiv: self.ipiv.clone(),
229        };
230        f.invh_into()
231    }
232}
233
234impl<A, S> InverseHInto for ArrayBase<S, Ix2>
235where
236    A: Scalar + Lapack,
237    S: DataMut<Elem = A>,
238{
239    type Output = Self;
240
241    fn invh_into(self) -> Result<Self::Output> {
242        let f = self.factorizeh_into()?;
243        f.invh_into()
244    }
245}
246
247impl<A> InverseH for ArrayRef<A, Ix2>
248where
249    A: Scalar + Lapack,
250{
251    type Output = Array2<A>;
252
253    fn invh(&self) -> Result<Self::Output> {
254        let f = self.factorizeh()?;
255        f.invh_into()
256    }
257}
258
259/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
260pub trait DeterminantH {
261    /// The element type of the matrix.
262    type Elem: Scalar;
263
264    /// Computes the determinant of the Hermitian (or real symmetric) matrix.
265    fn deth(&self) -> Result<<Self::Elem as Scalar>::Real>;
266
267    /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
268    /// (or real symmetric) matrix.
269    ///
270    /// The `natural_log` is the natural logarithm of the absolute value of the
271    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
272    /// is negative infinity.
273    ///
274    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
275    /// or just call `.deth()` instead.
276    ///
277    /// This method is more robust than `.deth()` to very small or very large
278    /// determinants since it returns the natural logarithm of the determinant
279    /// rather than the determinant itself.
280    fn sln_deth(&self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
281}
282
283/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
284pub trait DeterminantHInto {
285    /// The element type of the matrix.
286    type Elem: Scalar;
287
288    /// Computes the determinant of the Hermitian (or real symmetric) matrix.
289    fn deth_into(self) -> Result<<Self::Elem as Scalar>::Real>;
290
291    /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
292    /// (or real symmetric) matrix.
293    ///
294    /// The `natural_log` is the natural logarithm of the absolute value of the
295    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
296    /// is negative infinity.
297    ///
298    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
299    /// or just call `.deth_into()` instead.
300    ///
301    /// This method is more robust than `.deth_into()` to very small or very
302    /// large determinants since it returns the natural logarithm of the
303    /// determinant rather than the determinant itself.
304    fn sln_deth_into(self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
305}
306
307/// Returns the sign and natural log of the determinant.
308fn bk_sln_det<P, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayRef<A, Ix2>) -> (A::Real, A::Real)
309where
310    P: Iterator<Item = i32>,
311    A: Scalar + Lapack,
312{
313    let layout = a.layout().unwrap();
314    let mut sign = A::Real::one();
315    let mut ln_det = A::Real::zero();
316    let mut ipiv_enum = ipiv_iter.enumerate();
317    while let Some((k, ipiv_k)) = ipiv_enum.next() {
318        debug_assert!(k < a.nrows() && k < a.ncols());
319        if ipiv_k > 0 {
320            // 1x1 block at k, must be real.
321            let elem = unsafe { a.uget((k, k)) }.re();
322            debug_assert_eq!(elem.im(), Zero::zero());
323            sign *= elem.signum();
324            ln_det += Float::ln(Float::abs(elem));
325        } else {
326            // 2x2 block at k..k+2.
327
328            // Upper left diagonal elem, must be real.
329            let upper_diag = unsafe { a.uget((k, k)) }.re();
330            debug_assert_eq!(upper_diag.im(), Zero::zero());
331
332            // Lower right diagonal elem, must be real.
333            let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.re();
334            debug_assert_eq!(lower_diag.im(), Zero::zero());
335
336            // Off-diagonal elements, can be complex.
337            let off_diag = match layout {
338                MatrixLayout::C { .. } => match uplo {
339                    UPLO::Upper => unsafe { a.uget((k + 1, k)) },
340                    UPLO::Lower => unsafe { a.uget((k, k + 1)) },
341                },
342                MatrixLayout::F { .. } => match uplo {
343                    UPLO::Upper => unsafe { a.uget((k, k + 1)) },
344                    UPLO::Lower => unsafe { a.uget((k + 1, k)) },
345                },
346            };
347
348            // Determinant of 2x2 block.
349            let block_det = upper_diag * lower_diag - off_diag.square();
350            sign *= block_det.signum();
351            ln_det += Float::ln(Float::abs(block_det));
352
353            // Skip the k+1 ipiv value.
354            ipiv_enum.next();
355        }
356    }
357    (sign, ln_det)
358}
359
360impl<A, S> BKFactorized<S>
361where
362    A: Scalar + Lapack,
363    S: Data<Elem = A>,
364{
365    /// Computes the determinant of the factorized Hermitian (or real
366    /// symmetric) matrix.
367    pub fn deth(&self) -> A::Real {
368        let (sign, ln_det) = self.sln_deth();
369        sign * Float::exp(ln_det)
370    }
371
372    /// Computes the `(sign, natural_log)` of the determinant of the factorized
373    /// Hermitian (or real symmetric) matrix.
374    ///
375    /// The `natural_log` is the natural logarithm of the absolute value of the
376    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
377    /// is negative infinity.
378    ///
379    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
380    /// or just call `.deth()` instead.
381    ///
382    /// This method is more robust than `.deth()` to very small or very large
383    /// determinants since it returns the natural logarithm of the determinant
384    /// rather than the determinant itself.
385    pub fn sln_deth(&self) -> (A::Real, A::Real) {
386        bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
387    }
388
389    /// Computes the determinant of the factorized Hermitian (or real
390    /// symmetric) matrix.
391    pub fn deth_into(self) -> A::Real {
392        let (sign, ln_det) = self.sln_deth_into();
393        sign * Float::exp(ln_det)
394    }
395
396    /// Computes the `(sign, natural_log)` of the determinant of the factorized
397    /// Hermitian (or real symmetric) matrix.
398    ///
399    /// The `natural_log` is the natural logarithm of the absolute value of the
400    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
401    /// is negative infinity.
402    ///
403    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
404    /// or just call `.deth_into()` instead.
405    ///
406    /// This method is more robust than `.deth_into()` to very small or very
407    /// large determinants since it returns the natural logarithm of the
408    /// determinant rather than the determinant itself.
409    pub fn sln_deth_into(self) -> (A::Real, A::Real) {
410        bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
411    }
412}
413
414impl<A> DeterminantH for ArrayRef<A, Ix2>
415where
416    A: Scalar + Lapack,
417{
418    type Elem = A;
419
420    fn deth(&self) -> Result<A::Real> {
421        let (sign, ln_det) = self.sln_deth()?;
422        Ok(sign * Float::exp(ln_det))
423    }
424
425    fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
426        match self.factorizeh() {
427            Ok(fac) => Ok(fac.sln_deth()),
428            Err(LinalgError::Lapack(e))
429                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
430            {
431                // Determinant is zero.
432                Ok((A::Real::zero(), A::Real::neg_infinity()))
433            }
434            Err(err) => Err(err),
435        }
436    }
437}
438
439impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
440where
441    A: Scalar + Lapack,
442    S: DataMut<Elem = A>,
443{
444    type Elem = A;
445
446    fn deth_into(self) -> Result<A::Real> {
447        let (sign, ln_det) = self.sln_deth_into()?;
448        Ok(sign * Float::exp(ln_det))
449    }
450
451    fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
452        match self.factorizeh_into() {
453            Ok(fac) => Ok(fac.sln_deth_into()),
454            Err(LinalgError::Lapack(e))
455                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
456            {
457                // Determinant is zero.
458                Ok((A::Real::zero(), A::Real::neg_infinity()))
459            }
460            Err(err) => Err(err),
461        }
462    }
463}