ndarray_linalg/solveh.rs
//! Solve Hermitian (or real symmetric) linear problems and invert Hermitian
//! (or real symmetric) matrices
//!
//! **Note that only the upper triangular portion of the matrix is used.**
//!
//! # Examples
//!
//! Solve `A * x = b`, where `A` is a Hermitian (or real symmetric) matrix:
//!
//! ```
//! use ndarray::prelude::*;
//! use ndarray_linalg::SolveH;
//!
//! let a: Array2<f64> = array![
//!     [3., 2., -1.],
//!     [2., -2., 4.],
//!     [-1., 4., 5.]
//! ];
//! let b: Array1<f64> = array![11., -12., 1.];
//! let x = a.solveh_into(b).unwrap();
//! assert!(x.abs_diff_eq(&array![1., 3., -2.], 1e-9));
//! ```
//!
//! If you are solving multiple systems of linear equations with the same
//! Hermitian or real symmetric coefficient matrix `A`, it's faster to compute
//! the factorization once at the beginning than to solve directly using `A`
//! each time:
//!
//! ```
//! use ndarray::prelude::*;
//! use ndarray_linalg::*;
//!
//! // Use a fixed algorithm and seed of PRNG for a reproducible test
//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
//!
//! let a: Array2<f64> = random_using((3, 3), &mut rng);
//! let f = a.factorizeh_into().unwrap(); // Factorize A (A is consumed)
//! for _ in 0..10 {
//!     let b: Array1<f64> = random_using(3, &mut rng);
//!     let x = f.solveh_into(b).unwrap(); // Solve A * x = b using the factorization
//! }
//! ```
42
43use ndarray::*;
44use num_traits::{Float, One, Zero};
45
46use crate::convert::*;
47use crate::error::*;
48use crate::layout::*;
49use crate::types::*;
50
51pub use lax::{Pivot, UPLO};
52
53/// An interface for solving systems of Hermitian (or real symmetric) linear equations.
54///
55/// If you plan to solve many equations with the same Hermitian (or real
56/// symmetric) coefficient matrix `A` but different `b` vectors, it's faster to
57/// factor the `A` matrix once using the `FactorizeH` trait, and then solve
58/// using the `BKFactorized` struct.
59pub trait SolveH<A: Scalar> {
60 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
61 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
62 /// `x` is the successful result.
63 ///
64 /// # Panics
65 ///
66 /// Panics if the length of `b` is not the equal to the number of columns
67 /// of `A`.
68 fn solveh(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
69 let mut b = replicate(b);
70 self.solveh_inplace(&mut b)?;
71 Ok(b)
72 }
73
74 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
75 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
76 /// `x` is the successful result.
77 ///
78 /// # Panics
79 ///
80 /// Panics if the length of `b` is not the equal to the number of columns
81 /// of `A`.
82 fn solveh_into<S: DataMut<Elem = A>>(
83 &self,
84 mut b: ArrayBase<S, Ix1>,
85 ) -> Result<ArrayBase<S, Ix1>> {
86 self.solveh_inplace(&mut b)?;
87 Ok(b)
88 }
89
90 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
91 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
92 /// `x` is the successful result. The value of `x` is also assigned to the
93 /// argument.
94 ///
95 /// # Panics
96 ///
97 /// Panics if the length of `b` is not the equal to the number of columns
98 /// of `A`.
99 fn solveh_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
100}
101
102/// Represents the Bunch–Kaufman factorization of a Hermitian (or real
103/// symmetric) matrix as `A = P * U * D * U^H * P^T`.
104pub struct BKFactorized<S: Data> {
105 pub a: ArrayBase<S, Ix2>,
106 pub ipiv: Pivot,
107}
108
109impl<A, S> SolveH<A> for BKFactorized<S>
110where
111 A: Scalar + Lapack,
112 S: Data<Elem = A>,
113{
114 fn solveh_inplace<'a>(
115 &self,
116 rhs: &'a mut ArrayRef<A, Ix1>,
117 ) -> Result<&'a mut ArrayRef<A, Ix1>> {
118 assert_eq!(
119 rhs.len(),
120 self.a.len_of(Axis(1)),
121 "The length of `rhs` must be compatible with the shape of the factored matrix.",
122 );
123 A::solveh(
124 self.a.square_layout()?,
125 UPLO::Upper,
126 self.a.as_allocated()?,
127 &self.ipiv,
128 rhs.as_slice_mut().unwrap(),
129 )?;
130 Ok(rhs)
131 }
132}
133
134impl<A> SolveH<A> for ArrayRef<A, Ix2>
135where
136 A: Scalar + Lapack,
137{
138 fn solveh_inplace<'a>(
139 &self,
140 rhs: &'a mut ArrayRef<A, Ix1>,
141 ) -> Result<&'a mut ArrayRef<A, Ix1>> {
142 let f = self.factorizeh()?;
143 f.solveh_inplace(rhs)
144 }
145}
146
147/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
148/// real symmetric) matrix refs.
149pub trait FactorizeH<S: Data> {
150 /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
151 /// symmetric) matrix.
152 fn factorizeh(&self) -> Result<BKFactorized<S>>;
153}
154
155/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
156/// real symmetric) matrices.
157pub trait FactorizeHInto<S: Data> {
158 /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
159 /// symmetric) matrix.
160 fn factorizeh_into(self) -> Result<BKFactorized<S>>;
161}
162
163impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
164where
165 A: Scalar + Lapack,
166 S: DataMut<Elem = A>,
167{
168 fn factorizeh_into(mut self) -> Result<BKFactorized<S>> {
169 let ipiv = A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)?;
170 Ok(BKFactorized { a: self, ipiv })
171 }
172}
173
174impl<A> FactorizeH<OwnedRepr<A>> for ArrayRef<A, Ix2>
175where
176 A: Scalar + Lapack,
177{
178 fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
179 let mut a: Array2<A> = replicate(self);
180 let ipiv = A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)?;
181 Ok(BKFactorized { a, ipiv })
182 }
183}
184
185/// An interface for inverting Hermitian (or real symmetric) matrix refs.
186pub trait InverseH {
187 type Output;
188 /// Computes the inverse of the Hermitian (or real symmetric) matrix.
189 fn invh(&self) -> Result<Self::Output>;
190}
191
192/// An interface for inverting Hermitian (or real symmetric) matrices.
193pub trait InverseHInto {
194 type Output;
195 /// Computes the inverse of the Hermitian (or real symmetric) matrix.
196 fn invh_into(self) -> Result<Self::Output>;
197}
198
199impl<A, S> InverseHInto for BKFactorized<S>
200where
201 A: Scalar + Lapack,
202 S: DataMut<Elem = A>,
203{
204 type Output = ArrayBase<S, Ix2>;
205
206 fn invh_into(mut self) -> Result<ArrayBase<S, Ix2>> {
207 A::invh(
208 self.a.square_layout()?,
209 UPLO::Upper,
210 self.a.as_allocated_mut()?,
211 &self.ipiv,
212 )?;
213 triangular_fill_hermitian(&mut self.a, UPLO::Upper);
214 Ok(self.a)
215 }
216}
217
218impl<A, S> InverseH for BKFactorized<S>
219where
220 A: Scalar + Lapack,
221 S: Data<Elem = A>,
222{
223 type Output = Array2<A>;
224
225 fn invh(&self) -> Result<Self::Output> {
226 let f = BKFactorized {
227 a: replicate(&self.a),
228 ipiv: self.ipiv.clone(),
229 };
230 f.invh_into()
231 }
232}
233
234impl<A, S> InverseHInto for ArrayBase<S, Ix2>
235where
236 A: Scalar + Lapack,
237 S: DataMut<Elem = A>,
238{
239 type Output = Self;
240
241 fn invh_into(self) -> Result<Self::Output> {
242 let f = self.factorizeh_into()?;
243 f.invh_into()
244 }
245}
246
247impl<A> InverseH for ArrayRef<A, Ix2>
248where
249 A: Scalar + Lapack,
250{
251 type Output = Array2<A>;
252
253 fn invh(&self) -> Result<Self::Output> {
254 let f = self.factorizeh()?;
255 f.invh_into()
256 }
257}
258
259/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
260pub trait DeterminantH {
261 /// The element type of the matrix.
262 type Elem: Scalar;
263
264 /// Computes the determinant of the Hermitian (or real symmetric) matrix.
265 fn deth(&self) -> Result<<Self::Elem as Scalar>::Real>;
266
267 /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
268 /// (or real symmetric) matrix.
269 ///
270 /// The `natural_log` is the natural logarithm of the absolute value of the
271 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
272 /// is negative infinity.
273 ///
274 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
275 /// or just call `.deth()` instead.
276 ///
277 /// This method is more robust than `.deth()` to very small or very large
278 /// determinants since it returns the natural logarithm of the determinant
279 /// rather than the determinant itself.
280 fn sln_deth(&self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
281}
282
283/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
284pub trait DeterminantHInto {
285 /// The element type of the matrix.
286 type Elem: Scalar;
287
288 /// Computes the determinant of the Hermitian (or real symmetric) matrix.
289 fn deth_into(self) -> Result<<Self::Elem as Scalar>::Real>;
290
291 /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
292 /// (or real symmetric) matrix.
293 ///
294 /// The `natural_log` is the natural logarithm of the absolute value of the
295 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
296 /// is negative infinity.
297 ///
298 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
299 /// or just call `.deth_into()` instead.
300 ///
301 /// This method is more robust than `.deth_into()` to very small or very
302 /// large determinants since it returns the natural logarithm of the
303 /// determinant rather than the determinant itself.
304 fn sln_deth_into(self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
305}
306
307/// Returns the sign and natural log of the determinant.
308fn bk_sln_det<P, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayRef<A, Ix2>) -> (A::Real, A::Real)
309where
310 P: Iterator<Item = i32>,
311 A: Scalar + Lapack,
312{
313 let layout = a.layout().unwrap();
314 let mut sign = A::Real::one();
315 let mut ln_det = A::Real::zero();
316 let mut ipiv_enum = ipiv_iter.enumerate();
317 while let Some((k, ipiv_k)) = ipiv_enum.next() {
318 debug_assert!(k < a.nrows() && k < a.ncols());
319 if ipiv_k > 0 {
320 // 1x1 block at k, must be real.
321 let elem = unsafe { a.uget((k, k)) }.re();
322 debug_assert_eq!(elem.im(), Zero::zero());
323 sign *= elem.signum();
324 ln_det += Float::ln(Float::abs(elem));
325 } else {
326 // 2x2 block at k..k+2.
327
328 // Upper left diagonal elem, must be real.
329 let upper_diag = unsafe { a.uget((k, k)) }.re();
330 debug_assert_eq!(upper_diag.im(), Zero::zero());
331
332 // Lower right diagonal elem, must be real.
333 let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.re();
334 debug_assert_eq!(lower_diag.im(), Zero::zero());
335
336 // Off-diagonal elements, can be complex.
337 let off_diag = match layout {
338 MatrixLayout::C { .. } => match uplo {
339 UPLO::Upper => unsafe { a.uget((k + 1, k)) },
340 UPLO::Lower => unsafe { a.uget((k, k + 1)) },
341 },
342 MatrixLayout::F { .. } => match uplo {
343 UPLO::Upper => unsafe { a.uget((k, k + 1)) },
344 UPLO::Lower => unsafe { a.uget((k + 1, k)) },
345 },
346 };
347
348 // Determinant of 2x2 block.
349 let block_det = upper_diag * lower_diag - off_diag.square();
350 sign *= block_det.signum();
351 ln_det += Float::ln(Float::abs(block_det));
352
353 // Skip the k+1 ipiv value.
354 ipiv_enum.next();
355 }
356 }
357 (sign, ln_det)
358}
359
360impl<A, S> BKFactorized<S>
361where
362 A: Scalar + Lapack,
363 S: Data<Elem = A>,
364{
365 /// Computes the determinant of the factorized Hermitian (or real
366 /// symmetric) matrix.
367 pub fn deth(&self) -> A::Real {
368 let (sign, ln_det) = self.sln_deth();
369 sign * Float::exp(ln_det)
370 }
371
372 /// Computes the `(sign, natural_log)` of the determinant of the factorized
373 /// Hermitian (or real symmetric) matrix.
374 ///
375 /// The `natural_log` is the natural logarithm of the absolute value of the
376 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
377 /// is negative infinity.
378 ///
379 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
380 /// or just call `.deth()` instead.
381 ///
382 /// This method is more robust than `.deth()` to very small or very large
383 /// determinants since it returns the natural logarithm of the determinant
384 /// rather than the determinant itself.
385 pub fn sln_deth(&self) -> (A::Real, A::Real) {
386 bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
387 }
388
389 /// Computes the determinant of the factorized Hermitian (or real
390 /// symmetric) matrix.
391 pub fn deth_into(self) -> A::Real {
392 let (sign, ln_det) = self.sln_deth_into();
393 sign * Float::exp(ln_det)
394 }
395
396 /// Computes the `(sign, natural_log)` of the determinant of the factorized
397 /// Hermitian (or real symmetric) matrix.
398 ///
399 /// The `natural_log` is the natural logarithm of the absolute value of the
400 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
401 /// is negative infinity.
402 ///
403 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
404 /// or just call `.deth_into()` instead.
405 ///
406 /// This method is more robust than `.deth_into()` to very small or very
407 /// large determinants since it returns the natural logarithm of the
408 /// determinant rather than the determinant itself.
409 pub fn sln_deth_into(self) -> (A::Real, A::Real) {
410 bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
411 }
412}
413
414impl<A> DeterminantH for ArrayRef<A, Ix2>
415where
416 A: Scalar + Lapack,
417{
418 type Elem = A;
419
420 fn deth(&self) -> Result<A::Real> {
421 let (sign, ln_det) = self.sln_deth()?;
422 Ok(sign * Float::exp(ln_det))
423 }
424
425 fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
426 match self.factorizeh() {
427 Ok(fac) => Ok(fac.sln_deth()),
428 Err(LinalgError::Lapack(e))
429 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
430 {
431 // Determinant is zero.
432 Ok((A::Real::zero(), A::Real::neg_infinity()))
433 }
434 Err(err) => Err(err),
435 }
436 }
437}
438
439impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
440where
441 A: Scalar + Lapack,
442 S: DataMut<Elem = A>,
443{
444 type Elem = A;
445
446 fn deth_into(self) -> Result<A::Real> {
447 let (sign, ln_det) = self.sln_deth_into()?;
448 Ok(sign * Float::exp(ln_det))
449 }
450
451 fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
452 match self.factorizeh_into() {
453 Ok(fac) => Ok(fac.sln_deth_into()),
454 Err(LinalgError::Lapack(e))
455 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
456 {
457 // Determinant is zero.
458 Ok((A::Real::zero(), A::Real::neg_infinity()))
459 }
460 Err(err) => Err(err),
461 }
462 }
463}