ndarray_linalg/
solveh.rs

1//! Solve Hermitian (or real symmetric) linear problems and invert Hermitian
2//! (or real symmetric) matrices
3//!
4//! **Note that only the upper triangular portion of the matrix is used.**
5//!
6//! # Examples
7//!
8//! Solve `A * x = b`, where `A` is a Hermitian (or real symmetric) matrix:
9//!
10//! ```
11//! use ndarray::prelude::*;
12//! use ndarray_linalg::SolveH;
13//!
14//! let a: Array2<f64> = array![
15//!     [3., 2., -1.],
16//!     [2., -2., 4.],
17//!     [-1., 4., 5.]
18//! ];
19//! let b: Array1<f64> = array![11., -12., 1.];
20//! let x = a.solveh_into(b).unwrap();
21//! assert!(x.abs_diff_eq(&array![1., 3., -2.], 1e-9));
22//! ```
23//!
24//! If you are solving multiple systems of linear equations with the same
25//! Hermitian or real symmetric coefficient matrix `A`, it's faster to compute
26//! the factorization once at the beginning than solving directly using `A`:
27//!
28//! ```
29//! use ndarray::prelude::*;
30//! use ndarray_linalg::*;
31//!
32//! /// Use fixed algorithm and seed of PRNG for reproducible test
33//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
34//!
35//! let a: Array2<f64> = random_using((3, 3), &mut rng);
36//! let f = a.factorizeh_into().unwrap(); // Factorize A (A is consumed)
37//! for _ in 0..10 {
38//!     let b: Array1<f64> = random_using(3, &mut rng);
39//!     let x = f.solveh_into(b).unwrap(); // Solve A * x = b using the factorization
40//! }
41//! ```
42
43use ndarray::*;
44use num_traits::{Float, One, Zero};
45
46use crate::convert::*;
47use crate::error::*;
48use crate::layout::*;
49use crate::types::*;
50
51pub use lax::{Pivot, UPLO};
52
53/// An interface for solving systems of Hermitian (or real symmetric) linear equations.
54///
55/// If you plan to solve many equations with the same Hermitian (or real
56/// symmetric) coefficient matrix `A` but different `b` vectors, it's faster to
57/// factor the `A` matrix once using the `FactorizeH` trait, and then solve
58/// using the `BKFactorized` struct.
59pub trait SolveH<A: Scalar> {
60    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
61    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
62    /// `x` is the successful result.
63    ///
64    /// # Panics
65    ///
66    /// Panics if the length of `b` is not the equal to the number of columns
67    /// of `A`.
68    fn solveh<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
69        let mut b = replicate(b);
70        self.solveh_inplace(&mut b)?;
71        Ok(b)
72    }
73
74    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
75    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
76    /// `x` is the successful result.
77    ///
78    /// # Panics
79    ///
80    /// Panics if the length of `b` is not the equal to the number of columns
81    /// of `A`.
82    fn solveh_into<S: DataMut<Elem = A>>(
83        &self,
84        mut b: ArrayBase<S, Ix1>,
85    ) -> Result<ArrayBase<S, Ix1>> {
86        self.solveh_inplace(&mut b)?;
87        Ok(b)
88    }
89
90    /// Solves a system of linear equations `A * x = b` with Hermitian (or real
91    /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
92    /// `x` is the successful result. The value of `x` is also assigned to the
93    /// argument.
94    ///
95    /// # Panics
96    ///
97    /// Panics if the length of `b` is not the equal to the number of columns
98    /// of `A`.
99    fn solveh_inplace<'a, S: DataMut<Elem = A>>(
100        &self,
101        b: &'a mut ArrayBase<S, Ix1>,
102    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
103}
104
105/// Represents the Bunch–Kaufman factorization of a Hermitian (or real
106/// symmetric) matrix as `A = P * U * D * U^H * P^T`.
107pub struct BKFactorized<S: Data> {
108    pub a: ArrayBase<S, Ix2>,
109    pub ipiv: Pivot,
110}
111
112impl<A, S> SolveH<A> for BKFactorized<S>
113where
114    A: Scalar + Lapack,
115    S: Data<Elem = A>,
116{
117    fn solveh_inplace<'a, Sb>(
118        &self,
119        rhs: &'a mut ArrayBase<Sb, Ix1>,
120    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
121    where
122        Sb: DataMut<Elem = A>,
123    {
124        assert_eq!(
125            rhs.len(),
126            self.a.len_of(Axis(1)),
127            "The length of `rhs` must be compatible with the shape of the factored matrix.",
128        );
129        A::solveh(
130            self.a.square_layout()?,
131            UPLO::Upper,
132            self.a.as_allocated()?,
133            &self.ipiv,
134            rhs.as_slice_mut().unwrap(),
135        )?;
136        Ok(rhs)
137    }
138}
139
140impl<A, S> SolveH<A> for ArrayBase<S, Ix2>
141where
142    A: Scalar + Lapack,
143    S: Data<Elem = A>,
144{
145    fn solveh_inplace<'a, Sb>(
146        &self,
147        rhs: &'a mut ArrayBase<Sb, Ix1>,
148    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
149    where
150        Sb: DataMut<Elem = A>,
151    {
152        let f = self.factorizeh()?;
153        f.solveh_inplace(rhs)
154    }
155}
156
157/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
158/// real symmetric) matrix refs.
159pub trait FactorizeH<S: Data> {
160    /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
161    /// symmetric) matrix.
162    fn factorizeh(&self) -> Result<BKFactorized<S>>;
163}
164
165/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
166/// real symmetric) matrices.
167pub trait FactorizeHInto<S: Data> {
168    /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
169    /// symmetric) matrix.
170    fn factorizeh_into(self) -> Result<BKFactorized<S>>;
171}
172
173impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
174where
175    A: Scalar + Lapack,
176    S: DataMut<Elem = A>,
177{
178    fn factorizeh_into(mut self) -> Result<BKFactorized<S>> {
179        let ipiv = A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)?;
180        Ok(BKFactorized { a: self, ipiv })
181    }
182}
183
184impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2>
185where
186    A: Scalar + Lapack,
187    Si: Data<Elem = A>,
188{
189    fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
190        let mut a: Array2<A> = replicate(self);
191        let ipiv = A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)?;
192        Ok(BKFactorized { a, ipiv })
193    }
194}
195
196/// An interface for inverting Hermitian (or real symmetric) matrix refs.
197pub trait InverseH {
198    type Output;
199    /// Computes the inverse of the Hermitian (or real symmetric) matrix.
200    fn invh(&self) -> Result<Self::Output>;
201}
202
203/// An interface for inverting Hermitian (or real symmetric) matrices.
204pub trait InverseHInto {
205    type Output;
206    /// Computes the inverse of the Hermitian (or real symmetric) matrix.
207    fn invh_into(self) -> Result<Self::Output>;
208}
209
210impl<A, S> InverseHInto for BKFactorized<S>
211where
212    A: Scalar + Lapack,
213    S: DataMut<Elem = A>,
214{
215    type Output = ArrayBase<S, Ix2>;
216
217    fn invh_into(mut self) -> Result<ArrayBase<S, Ix2>> {
218        A::invh(
219            self.a.square_layout()?,
220            UPLO::Upper,
221            self.a.as_allocated_mut()?,
222            &self.ipiv,
223        )?;
224        triangular_fill_hermitian(&mut self.a, UPLO::Upper);
225        Ok(self.a)
226    }
227}
228
229impl<A, S> InverseH for BKFactorized<S>
230where
231    A: Scalar + Lapack,
232    S: Data<Elem = A>,
233{
234    type Output = Array2<A>;
235
236    fn invh(&self) -> Result<Self::Output> {
237        let f = BKFactorized {
238            a: replicate(&self.a),
239            ipiv: self.ipiv.clone(),
240        };
241        f.invh_into()
242    }
243}
244
245impl<A, S> InverseHInto for ArrayBase<S, Ix2>
246where
247    A: Scalar + Lapack,
248    S: DataMut<Elem = A>,
249{
250    type Output = Self;
251
252    fn invh_into(self) -> Result<Self::Output> {
253        let f = self.factorizeh_into()?;
254        f.invh_into()
255    }
256}
257
258impl<A, Si> InverseH for ArrayBase<Si, Ix2>
259where
260    A: Scalar + Lapack,
261    Si: Data<Elem = A>,
262{
263    type Output = Array2<A>;
264
265    fn invh(&self) -> Result<Self::Output> {
266        let f = self.factorizeh()?;
267        f.invh_into()
268    }
269}
270
271/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
272pub trait DeterminantH {
273    /// The element type of the matrix.
274    type Elem: Scalar;
275
276    /// Computes the determinant of the Hermitian (or real symmetric) matrix.
277    fn deth(&self) -> Result<<Self::Elem as Scalar>::Real>;
278
279    /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
280    /// (or real symmetric) matrix.
281    ///
282    /// The `natural_log` is the natural logarithm of the absolute value of the
283    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
284    /// is negative infinity.
285    ///
286    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
287    /// or just call `.deth()` instead.
288    ///
289    /// This method is more robust than `.deth()` to very small or very large
290    /// determinants since it returns the natural logarithm of the determinant
291    /// rather than the determinant itself.
292    fn sln_deth(&self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
293}
294
295/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
296pub trait DeterminantHInto {
297    /// The element type of the matrix.
298    type Elem: Scalar;
299
300    /// Computes the determinant of the Hermitian (or real symmetric) matrix.
301    fn deth_into(self) -> Result<<Self::Elem as Scalar>::Real>;
302
303    /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
304    /// (or real symmetric) matrix.
305    ///
306    /// The `natural_log` is the natural logarithm of the absolute value of the
307    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
308    /// is negative infinity.
309    ///
310    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
311    /// or just call `.deth_into()` instead.
312    ///
313    /// This method is more robust than `.deth_into()` to very small or very
314    /// large determinants since it returns the natural logarithm of the
315    /// determinant rather than the determinant itself.
316    fn sln_deth_into(self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
317}
318
319/// Returns the sign and natural log of the determinant.
320fn bk_sln_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> (A::Real, A::Real)
321where
322    P: Iterator<Item = i32>,
323    S: Data<Elem = A>,
324    A: Scalar + Lapack,
325{
326    let layout = a.layout().unwrap();
327    let mut sign = A::Real::one();
328    let mut ln_det = A::Real::zero();
329    let mut ipiv_enum = ipiv_iter.enumerate();
330    while let Some((k, ipiv_k)) = ipiv_enum.next() {
331        debug_assert!(k < a.nrows() && k < a.ncols());
332        if ipiv_k > 0 {
333            // 1x1 block at k, must be real.
334            let elem = unsafe { a.uget((k, k)) }.re();
335            debug_assert_eq!(elem.im(), Zero::zero());
336            sign *= elem.signum();
337            ln_det += Float::ln(Float::abs(elem));
338        } else {
339            // 2x2 block at k..k+2.
340
341            // Upper left diagonal elem, must be real.
342            let upper_diag = unsafe { a.uget((k, k)) }.re();
343            debug_assert_eq!(upper_diag.im(), Zero::zero());
344
345            // Lower right diagonal elem, must be real.
346            let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.re();
347            debug_assert_eq!(lower_diag.im(), Zero::zero());
348
349            // Off-diagonal elements, can be complex.
350            let off_diag = match layout {
351                MatrixLayout::C { .. } => match uplo {
352                    UPLO::Upper => unsafe { a.uget((k + 1, k)) },
353                    UPLO::Lower => unsafe { a.uget((k, k + 1)) },
354                },
355                MatrixLayout::F { .. } => match uplo {
356                    UPLO::Upper => unsafe { a.uget((k, k + 1)) },
357                    UPLO::Lower => unsafe { a.uget((k + 1, k)) },
358                },
359            };
360
361            // Determinant of 2x2 block.
362            let block_det = upper_diag * lower_diag - off_diag.square();
363            sign *= block_det.signum();
364            ln_det += Float::ln(Float::abs(block_det));
365
366            // Skip the k+1 ipiv value.
367            ipiv_enum.next();
368        }
369    }
370    (sign, ln_det)
371}
372
373impl<A, S> BKFactorized<S>
374where
375    A: Scalar + Lapack,
376    S: Data<Elem = A>,
377{
378    /// Computes the determinant of the factorized Hermitian (or real
379    /// symmetric) matrix.
380    pub fn deth(&self) -> A::Real {
381        let (sign, ln_det) = self.sln_deth();
382        sign * Float::exp(ln_det)
383    }
384
385    /// Computes the `(sign, natural_log)` of the determinant of the factorized
386    /// Hermitian (or real symmetric) matrix.
387    ///
388    /// The `natural_log` is the natural logarithm of the absolute value of the
389    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
390    /// is negative infinity.
391    ///
392    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
393    /// or just call `.deth()` instead.
394    ///
395    /// This method is more robust than `.deth()` to very small or very large
396    /// determinants since it returns the natural logarithm of the determinant
397    /// rather than the determinant itself.
398    pub fn sln_deth(&self) -> (A::Real, A::Real) {
399        bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
400    }
401
402    /// Computes the determinant of the factorized Hermitian (or real
403    /// symmetric) matrix.
404    pub fn deth_into(self) -> A::Real {
405        let (sign, ln_det) = self.sln_deth_into();
406        sign * Float::exp(ln_det)
407    }
408
409    /// Computes the `(sign, natural_log)` of the determinant of the factorized
410    /// Hermitian (or real symmetric) matrix.
411    ///
412    /// The `natural_log` is the natural logarithm of the absolute value of the
413    /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
414    /// is negative infinity.
415    ///
416    /// To obtain the determinant, you can compute `sign * natural_log.exp()`
417    /// or just call `.deth_into()` instead.
418    ///
419    /// This method is more robust than `.deth_into()` to very small or very
420    /// large determinants since it returns the natural logarithm of the
421    /// determinant rather than the determinant itself.
422    pub fn sln_deth_into(self) -> (A::Real, A::Real) {
423        bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
424    }
425}
426
427impl<A, S> DeterminantH for ArrayBase<S, Ix2>
428where
429    A: Scalar + Lapack,
430    S: Data<Elem = A>,
431{
432    type Elem = A;
433
434    fn deth(&self) -> Result<A::Real> {
435        let (sign, ln_det) = self.sln_deth()?;
436        Ok(sign * Float::exp(ln_det))
437    }
438
439    fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
440        match self.factorizeh() {
441            Ok(fac) => Ok(fac.sln_deth()),
442            Err(LinalgError::Lapack(e))
443                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
444            {
445                // Determinant is zero.
446                Ok((A::Real::zero(), A::Real::neg_infinity()))
447            }
448            Err(err) => Err(err),
449        }
450    }
451}
452
453impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
454where
455    A: Scalar + Lapack,
456    S: DataMut<Elem = A>,
457{
458    type Elem = A;
459
460    fn deth_into(self) -> Result<A::Real> {
461        let (sign, ln_det) = self.sln_deth_into()?;
462        Ok(sign * Float::exp(ln_det))
463    }
464
465    fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
466        match self.factorizeh_into() {
467            Ok(fac) => Ok(fac.sln_deth_into()),
468            Err(LinalgError::Lapack(e))
469                if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
470            {
471                // Determinant is zero.
472                Ok((A::Real::zero(), A::Real::neg_infinity()))
473            }
474            Err(err) => Err(err),
475        }
476    }
477}