ndarray_linalg/solveh.rs
1//! Solve Hermitian (or real symmetric) linear problems and invert Hermitian
2//! (or real symmetric) matrices
3//!
4//! **Note that only the upper triangular portion of the matrix is used.**
5//!
6//! # Examples
7//!
8//! Solve `A * x = b`, where `A` is a Hermitian (or real symmetric) matrix:
9//!
10//! ```
11//! use ndarray::prelude::*;
12//! use ndarray_linalg::SolveH;
13//!
14//! let a: Array2<f64> = array![
15//! [3., 2., -1.],
16//! [2., -2., 4.],
17//! [-1., 4., 5.]
18//! ];
19//! let b: Array1<f64> = array![11., -12., 1.];
20//! let x = a.solveh_into(b).unwrap();
21//! assert!(x.abs_diff_eq(&array![1., 3., -2.], 1e-9));
22//! ```
23//!
//! If you are solving multiple systems of linear equations with the same
//! Hermitian or real symmetric coefficient matrix `A`, it's faster to compute
//! the factorization once and reuse it than to solve each system directly
//! with `A`:
27//!
28//! ```
29//! use ndarray::prelude::*;
30//! use ndarray_linalg::*;
31//!
//! // Use a fixed PRNG algorithm and seed so the example is reproducible.
33//! let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
34//!
35//! let a: Array2<f64> = random_using((3, 3), &mut rng);
36//! let f = a.factorizeh_into().unwrap(); // Factorize A (A is consumed)
37//! for _ in 0..10 {
38//! let b: Array1<f64> = random_using(3, &mut rng);
39//! let x = f.solveh_into(b).unwrap(); // Solve A * x = b using the factorization
40//! }
41//! ```
42
43use ndarray::*;
44use num_traits::{Float, One, Zero};
45
46use crate::convert::*;
47use crate::error::*;
48use crate::layout::*;
49use crate::types::*;
50
51pub use lax::{Pivot, UPLO};
52
53/// An interface for solving systems of Hermitian (or real symmetric) linear equations.
54///
55/// If you plan to solve many equations with the same Hermitian (or real
56/// symmetric) coefficient matrix `A` but different `b` vectors, it's faster to
57/// factor the `A` matrix once using the `FactorizeH` trait, and then solve
58/// using the `BKFactorized` struct.
59pub trait SolveH<A: Scalar> {
60 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
61 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
62 /// `x` is the successful result.
63 ///
64 /// # Panics
65 ///
66 /// Panics if the length of `b` is not the equal to the number of columns
67 /// of `A`.
68 fn solveh<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
69 let mut b = replicate(b);
70 self.solveh_inplace(&mut b)?;
71 Ok(b)
72 }
73
74 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
75 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
76 /// `x` is the successful result.
77 ///
78 /// # Panics
79 ///
80 /// Panics if the length of `b` is not the equal to the number of columns
81 /// of `A`.
82 fn solveh_into<S: DataMut<Elem = A>>(
83 &self,
84 mut b: ArrayBase<S, Ix1>,
85 ) -> Result<ArrayBase<S, Ix1>> {
86 self.solveh_inplace(&mut b)?;
87 Ok(b)
88 }
89
90 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
91 /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
92 /// `x` is the successful result. The value of `x` is also assigned to the
93 /// argument.
94 ///
95 /// # Panics
96 ///
97 /// Panics if the length of `b` is not the equal to the number of columns
98 /// of `A`.
99 fn solveh_inplace<'a, S: DataMut<Elem = A>>(
100 &self,
101 b: &'a mut ArrayBase<S, Ix1>,
102 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
103}
104
105/// Represents the Bunch–Kaufman factorization of a Hermitian (or real
106/// symmetric) matrix as `A = P * U * D * U^H * P^T`.
107pub struct BKFactorized<S: Data> {
108 pub a: ArrayBase<S, Ix2>,
109 pub ipiv: Pivot,
110}
111
112impl<A, S> SolveH<A> for BKFactorized<S>
113where
114 A: Scalar + Lapack,
115 S: Data<Elem = A>,
116{
117 fn solveh_inplace<'a, Sb>(
118 &self,
119 rhs: &'a mut ArrayBase<Sb, Ix1>,
120 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
121 where
122 Sb: DataMut<Elem = A>,
123 {
124 assert_eq!(
125 rhs.len(),
126 self.a.len_of(Axis(1)),
127 "The length of `rhs` must be compatible with the shape of the factored matrix.",
128 );
129 A::solveh(
130 self.a.square_layout()?,
131 UPLO::Upper,
132 self.a.as_allocated()?,
133 &self.ipiv,
134 rhs.as_slice_mut().unwrap(),
135 )?;
136 Ok(rhs)
137 }
138}
139
140impl<A, S> SolveH<A> for ArrayBase<S, Ix2>
141where
142 A: Scalar + Lapack,
143 S: Data<Elem = A>,
144{
145 fn solveh_inplace<'a, Sb>(
146 &self,
147 rhs: &'a mut ArrayBase<Sb, Ix1>,
148 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
149 where
150 Sb: DataMut<Elem = A>,
151 {
152 let f = self.factorizeh()?;
153 f.solveh_inplace(rhs)
154 }
155}
156
157/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
158/// real symmetric) matrix refs.
159pub trait FactorizeH<S: Data> {
160 /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
161 /// symmetric) matrix.
162 fn factorizeh(&self) -> Result<BKFactorized<S>>;
163}
164
165/// An interface for computing the Bunch–Kaufman factorization of Hermitian (or
166/// real symmetric) matrices.
167pub trait FactorizeHInto<S: Data> {
168 /// Computes the Bunch–Kaufman factorization of a Hermitian (or real
169 /// symmetric) matrix.
170 fn factorizeh_into(self) -> Result<BKFactorized<S>>;
171}
172
173impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
174where
175 A: Scalar + Lapack,
176 S: DataMut<Elem = A>,
177{
178 fn factorizeh_into(mut self) -> Result<BKFactorized<S>> {
179 let ipiv = A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)?;
180 Ok(BKFactorized { a: self, ipiv })
181 }
182}
183
184impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2>
185where
186 A: Scalar + Lapack,
187 Si: Data<Elem = A>,
188{
189 fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
190 let mut a: Array2<A> = replicate(self);
191 let ipiv = A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)?;
192 Ok(BKFactorized { a, ipiv })
193 }
194}
195
196/// An interface for inverting Hermitian (or real symmetric) matrix refs.
197pub trait InverseH {
198 type Output;
199 /// Computes the inverse of the Hermitian (or real symmetric) matrix.
200 fn invh(&self) -> Result<Self::Output>;
201}
202
203/// An interface for inverting Hermitian (or real symmetric) matrices.
204pub trait InverseHInto {
205 type Output;
206 /// Computes the inverse of the Hermitian (or real symmetric) matrix.
207 fn invh_into(self) -> Result<Self::Output>;
208}
209
210impl<A, S> InverseHInto for BKFactorized<S>
211where
212 A: Scalar + Lapack,
213 S: DataMut<Elem = A>,
214{
215 type Output = ArrayBase<S, Ix2>;
216
217 fn invh_into(mut self) -> Result<ArrayBase<S, Ix2>> {
218 A::invh(
219 self.a.square_layout()?,
220 UPLO::Upper,
221 self.a.as_allocated_mut()?,
222 &self.ipiv,
223 )?;
224 triangular_fill_hermitian(&mut self.a, UPLO::Upper);
225 Ok(self.a)
226 }
227}
228
229impl<A, S> InverseH for BKFactorized<S>
230where
231 A: Scalar + Lapack,
232 S: Data<Elem = A>,
233{
234 type Output = Array2<A>;
235
236 fn invh(&self) -> Result<Self::Output> {
237 let f = BKFactorized {
238 a: replicate(&self.a),
239 ipiv: self.ipiv.clone(),
240 };
241 f.invh_into()
242 }
243}
244
245impl<A, S> InverseHInto for ArrayBase<S, Ix2>
246where
247 A: Scalar + Lapack,
248 S: DataMut<Elem = A>,
249{
250 type Output = Self;
251
252 fn invh_into(self) -> Result<Self::Output> {
253 let f = self.factorizeh_into()?;
254 f.invh_into()
255 }
256}
257
258impl<A, Si> InverseH for ArrayBase<Si, Ix2>
259where
260 A: Scalar + Lapack,
261 Si: Data<Elem = A>,
262{
263 type Output = Array2<A>;
264
265 fn invh(&self) -> Result<Self::Output> {
266 let f = self.factorizeh()?;
267 f.invh_into()
268 }
269}
270
271/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
272pub trait DeterminantH {
273 /// The element type of the matrix.
274 type Elem: Scalar;
275
276 /// Computes the determinant of the Hermitian (or real symmetric) matrix.
277 fn deth(&self) -> Result<<Self::Elem as Scalar>::Real>;
278
279 /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
280 /// (or real symmetric) matrix.
281 ///
282 /// The `natural_log` is the natural logarithm of the absolute value of the
283 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
284 /// is negative infinity.
285 ///
286 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
287 /// or just call `.deth()` instead.
288 ///
289 /// This method is more robust than `.deth()` to very small or very large
290 /// determinants since it returns the natural logarithm of the determinant
291 /// rather than the determinant itself.
292 fn sln_deth(&self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
293}
294
295/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
296pub trait DeterminantHInto {
297 /// The element type of the matrix.
298 type Elem: Scalar;
299
300 /// Computes the determinant of the Hermitian (or real symmetric) matrix.
301 fn deth_into(self) -> Result<<Self::Elem as Scalar>::Real>;
302
303 /// Computes the `(sign, natural_log)` of the determinant of the Hermitian
304 /// (or real symmetric) matrix.
305 ///
306 /// The `natural_log` is the natural logarithm of the absolute value of the
307 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
308 /// is negative infinity.
309 ///
310 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
311 /// or just call `.deth_into()` instead.
312 ///
313 /// This method is more robust than `.deth_into()` to very small or very
314 /// large determinants since it returns the natural logarithm of the
315 /// determinant rather than the determinant itself.
316 fn sln_deth_into(self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
317}
318
319/// Returns the sign and natural log of the determinant.
320fn bk_sln_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> (A::Real, A::Real)
321where
322 P: Iterator<Item = i32>,
323 S: Data<Elem = A>,
324 A: Scalar + Lapack,
325{
326 let layout = a.layout().unwrap();
327 let mut sign = A::Real::one();
328 let mut ln_det = A::Real::zero();
329 let mut ipiv_enum = ipiv_iter.enumerate();
330 while let Some((k, ipiv_k)) = ipiv_enum.next() {
331 debug_assert!(k < a.nrows() && k < a.ncols());
332 if ipiv_k > 0 {
333 // 1x1 block at k, must be real.
334 let elem = unsafe { a.uget((k, k)) }.re();
335 debug_assert_eq!(elem.im(), Zero::zero());
336 sign *= elem.signum();
337 ln_det += Float::ln(Float::abs(elem));
338 } else {
339 // 2x2 block at k..k+2.
340
341 // Upper left diagonal elem, must be real.
342 let upper_diag = unsafe { a.uget((k, k)) }.re();
343 debug_assert_eq!(upper_diag.im(), Zero::zero());
344
345 // Lower right diagonal elem, must be real.
346 let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.re();
347 debug_assert_eq!(lower_diag.im(), Zero::zero());
348
349 // Off-diagonal elements, can be complex.
350 let off_diag = match layout {
351 MatrixLayout::C { .. } => match uplo {
352 UPLO::Upper => unsafe { a.uget((k + 1, k)) },
353 UPLO::Lower => unsafe { a.uget((k, k + 1)) },
354 },
355 MatrixLayout::F { .. } => match uplo {
356 UPLO::Upper => unsafe { a.uget((k, k + 1)) },
357 UPLO::Lower => unsafe { a.uget((k + 1, k)) },
358 },
359 };
360
361 // Determinant of 2x2 block.
362 let block_det = upper_diag * lower_diag - off_diag.square();
363 sign *= block_det.signum();
364 ln_det += Float::ln(Float::abs(block_det));
365
366 // Skip the k+1 ipiv value.
367 ipiv_enum.next();
368 }
369 }
370 (sign, ln_det)
371}
372
373impl<A, S> BKFactorized<S>
374where
375 A: Scalar + Lapack,
376 S: Data<Elem = A>,
377{
378 /// Computes the determinant of the factorized Hermitian (or real
379 /// symmetric) matrix.
380 pub fn deth(&self) -> A::Real {
381 let (sign, ln_det) = self.sln_deth();
382 sign * Float::exp(ln_det)
383 }
384
385 /// Computes the `(sign, natural_log)` of the determinant of the factorized
386 /// Hermitian (or real symmetric) matrix.
387 ///
388 /// The `natural_log` is the natural logarithm of the absolute value of the
389 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
390 /// is negative infinity.
391 ///
392 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
393 /// or just call `.deth()` instead.
394 ///
395 /// This method is more robust than `.deth()` to very small or very large
396 /// determinants since it returns the natural logarithm of the determinant
397 /// rather than the determinant itself.
398 pub fn sln_deth(&self) -> (A::Real, A::Real) {
399 bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
400 }
401
402 /// Computes the determinant of the factorized Hermitian (or real
403 /// symmetric) matrix.
404 pub fn deth_into(self) -> A::Real {
405 let (sign, ln_det) = self.sln_deth_into();
406 sign * Float::exp(ln_det)
407 }
408
409 /// Computes the `(sign, natural_log)` of the determinant of the factorized
410 /// Hermitian (or real symmetric) matrix.
411 ///
412 /// The `natural_log` is the natural logarithm of the absolute value of the
413 /// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
414 /// is negative infinity.
415 ///
416 /// To obtain the determinant, you can compute `sign * natural_log.exp()`
417 /// or just call `.deth_into()` instead.
418 ///
419 /// This method is more robust than `.deth_into()` to very small or very
420 /// large determinants since it returns the natural logarithm of the
421 /// determinant rather than the determinant itself.
422 pub fn sln_deth_into(self) -> (A::Real, A::Real) {
423 bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
424 }
425}
426
427impl<A, S> DeterminantH for ArrayBase<S, Ix2>
428where
429 A: Scalar + Lapack,
430 S: Data<Elem = A>,
431{
432 type Elem = A;
433
434 fn deth(&self) -> Result<A::Real> {
435 let (sign, ln_det) = self.sln_deth()?;
436 Ok(sign * Float::exp(ln_det))
437 }
438
439 fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
440 match self.factorizeh() {
441 Ok(fac) => Ok(fac.sln_deth()),
442 Err(LinalgError::Lapack(e))
443 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
444 {
445 // Determinant is zero.
446 Ok((A::Real::zero(), A::Real::neg_infinity()))
447 }
448 Err(err) => Err(err),
449 }
450 }
451}
452
453impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
454where
455 A: Scalar + Lapack,
456 S: DataMut<Elem = A>,
457{
458 type Elem = A;
459
460 fn deth_into(self) -> Result<A::Real> {
461 let (sign, ln_det) = self.sln_deth_into()?;
462 Ok(sign * Float::exp(ln_det))
463 }
464
465 fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
466 match self.factorizeh_into() {
467 Ok(fac) => Ok(fac.sln_deth_into()),
468 Err(LinalgError::Lapack(e))
469 if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
470 {
471 // Determinant is zero.
472 Ok((A::Real::zero(), A::Real::neg_infinity()))
473 }
474 Err(err) => Err(err),
475 }
476 }
477}