ndarray_linalg/cholesky.rs
1//! Cholesky decomposition of Hermitian (or real symmetric) positive definite matrices
2//!
3//! See the [Wikipedia page about Cholesky
4//! decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition) for
5//! more information.
6//!
7//! # Example
8//!
9//! Using the Cholesky decomposition of `A` for various operations, where `A`
10//! is a Hermitian (or real symmetric) positive definite matrix:
11//!
12//! ```
13//! #[macro_use]
14//! extern crate ndarray;
15//! extern crate ndarray_linalg;
16//!
17//! use ndarray::prelude::*;
18//! use ndarray_linalg::cholesky::*;
19//! # fn main() {
20//!
21//! let a: Array2<f64> = array![
22//! [ 4., 12., -16.],
23//! [ 12., 37., -43.],
24//! [-16., -43., 98.]
25//! ];
26//!
27//! // Obtain `L`
28//! let lower = a.cholesky(UPLO::Lower).unwrap();
29//! assert!(lower.abs_diff_eq(&array![
30//! [ 2., 0., 0.],
31//! [ 6., 1., 0.],
32//! [-8., 5., 3.]
33//! ], 1e-9));
34//!
35//! // Find the determinant of `A`
36//! let det = a.detc().unwrap();
37//! assert!((det - 36.).abs() < 1e-9);
38//!
39//! // Solve `A * x = b`
40//! let b = array![4., 13., -11.];
41//! let x = a.solvec(&b).unwrap();
42//! assert!(x.abs_diff_eq(&array![-2., 1., 0.], 1e-9));
43//! # }
44//! ```
45
46use ndarray::*;
47use num_traits::Float;
48
49use crate::convert::*;
50use crate::error::*;
51use crate::layout::*;
52use crate::triangular::IntoTriangular;
53use crate::types::*;
54
55pub use lax::UPLO;
56
57/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix
58pub struct CholeskyFactorized<S: Data> {
59 /// `L` from the decomposition `A = L * L^H` or `U` from the decomposition
60 /// `A = U^H * U`.
61 pub factor: ArrayBase<S, Ix2>,
62 /// If this is `UPLO::Lower`, then `self.factor` is `L`. If this is
63 /// `UPLO::Upper`, then `self.factor` is `U`.
64 pub uplo: UPLO,
65}
66
67impl<A, S> CholeskyFactorized<S>
68where
69 A: Scalar + Lapack,
70 S: DataMut<Elem = A>,
71{
72 /// Returns `L` from the Cholesky decomposition `A = L * L^H`.
73 ///
74 /// If `self.uplo == UPLO::Lower`, then no computations need to be
75 /// performed; otherwise, the conjugate transpose of `self.factor` is
76 /// calculated.
77 pub fn into_lower(self) -> ArrayBase<S, Ix2> {
78 match self.uplo {
79 UPLO::Lower => self.factor,
80 UPLO::Upper => self.factor.reversed_axes().mapv_into(|elem| elem.conj()),
81 }
82 }
83
84 /// Returns `U` from the Cholesky decomposition `A = U^H * U`.
85 ///
86 /// If `self.uplo == UPLO::Upper`, then no computations need to be
87 /// performed; otherwise, the conjugate transpose of `self.factor` is
88 /// calculated.
89 pub fn into_upper(self) -> ArrayBase<S, Ix2> {
90 match self.uplo {
91 UPLO::Lower => self.factor.reversed_axes().mapv_into(|elem| elem.conj()),
92 UPLO::Upper => self.factor,
93 }
94 }
95}
96
97impl<A, S> DeterminantC for CholeskyFactorized<S>
98where
99 A: Scalar + Lapack,
100 S: Data<Elem = A>,
101{
102 type Output = <A as Scalar>::Real;
103
104 fn detc(&self) -> Self::Output {
105 Float::exp(self.ln_detc())
106 }
107
108 fn ln_detc(&self) -> Self::Output {
109 self.factor
110 .diag()
111 .iter()
112 .map(|elem| Float::ln(elem.square()))
113 .sum::<Self::Output>()
114 }
115}
116
117impl<A, S> DeterminantCInto for CholeskyFactorized<S>
118where
119 A: Scalar + Lapack,
120 S: Data<Elem = A>,
121{
122 type Output = <A as Scalar>::Real;
123
124 fn detc_into(self) -> Self::Output {
125 self.detc()
126 }
127
128 fn ln_detc_into(self) -> Self::Output {
129 self.ln_detc()
130 }
131}
132
133impl<A, S> InverseC for CholeskyFactorized<S>
134where
135 A: Scalar + Lapack,
136 S: Data<Elem = A>,
137{
138 type Output = Array2<A>;
139
140 fn invc(&self) -> Result<Self::Output> {
141 let f = CholeskyFactorized {
142 factor: replicate(&self.factor),
143 uplo: self.uplo,
144 };
145 f.invc_into()
146 }
147}
148
149impl<A, S> InverseCInto for CholeskyFactorized<S>
150where
151 A: Scalar + Lapack,
152 S: DataMut<Elem = A>,
153{
154 type Output = ArrayBase<S, Ix2>;
155
156 fn invc_into(self) -> Result<Self::Output> {
157 let mut a = self.factor;
158 A::inv_cholesky(a.square_layout()?, self.uplo, a.as_allocated_mut()?)?;
159 triangular_fill_hermitian(&mut a, self.uplo);
160 Ok(a)
161 }
162}
163
164impl<A, S> SolveC<A> for CholeskyFactorized<S>
165where
166 A: Scalar + Lapack,
167 S: Data<Elem = A>,
168{
169 fn solvec_inplace<'a, Sb>(
170 &self,
171 b: &'a mut ArrayBase<Sb, Ix1>,
172 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
173 where
174 Sb: DataMut<Elem = A>,
175 {
176 A::solve_cholesky(
177 self.factor.square_layout()?,
178 self.uplo,
179 self.factor.as_allocated()?,
180 b.as_slice_mut().unwrap(),
181 )?;
182 Ok(b)
183 }
184}
185
186/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix reference
187pub trait Cholesky {
188 type Output;
189
190 /// Computes the Cholesky decomposition of the Hermitian (or real
191 /// symmetric) positive definite matrix.
192 ///
193 /// If the argument is `UPLO::Upper`, then computes the decomposition `A =
194 /// U^H * U` using the upper triangular portion of `A` and returns `U`.
195 /// Otherwise, if the argument is `UPLO::Lower`, computes the decomposition
196 /// `A = L * L^H` using the lower triangular portion of `A` and returns
197 /// `L`.
198 fn cholesky(&self, uplo: UPLO) -> Result<Self::Output>;
199}
200
201/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix
202pub trait CholeskyInto {
203 type Output;
204 /// Computes the Cholesky decomposition of the Hermitian (or real
205 /// symmetric) positive definite matrix.
206 ///
207 /// If the argument is `UPLO::Upper`, then computes the decomposition `A =
208 /// U^H * U` using the upper triangular portion of `A` and returns `U`.
209 /// Otherwise, if the argument is `UPLO::Lower`, computes the decomposition
210 /// `A = L * L^H` using the lower triangular portion of `A` and returns
211 /// `L`.
212 fn cholesky_into(self, uplo: UPLO) -> Result<Self::Output>;
213}
214
215/// Cholesky decomposition of Hermitian (or real symmetric) positive definite mutable reference of matrix
216pub trait CholeskyInplace {
217 /// Computes the Cholesky decomposition of the Hermitian (or real
218 /// symmetric) positive definite matrix, writing the result (`L` or `U`
219 /// according to the argument) to `self` and returning it.
220 ///
221 /// If the argument is `UPLO::Upper`, then computes the decomposition `A =
222 /// U^H * U` using the upper triangular portion of `A` and writes `U`.
223 /// Otherwise, if the argument is `UPLO::Lower`, computes the decomposition
224 /// `A = L * L^H` using the lower triangular portion of `A` and writes `L`.
225 fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self>;
226}
227
228impl<A, S> Cholesky for ArrayBase<S, Ix2>
229where
230 A: Scalar + Lapack,
231 S: Data<Elem = A>,
232{
233 type Output = Array2<A>;
234
235 fn cholesky(&self, uplo: UPLO) -> Result<Array2<A>> {
236 let a = replicate(self);
237 a.cholesky_into(uplo)
238 }
239}
240
241impl<A, S> CholeskyInto for ArrayBase<S, Ix2>
242where
243 A: Scalar + Lapack,
244 S: DataMut<Elem = A>,
245{
246 type Output = Self;
247
248 fn cholesky_into(mut self, uplo: UPLO) -> Result<Self> {
249 self.cholesky_inplace(uplo)?;
250 Ok(self)
251 }
252}
253
254impl<A, S> CholeskyInplace for ArrayBase<S, Ix2>
255where
256 A: Scalar + Lapack,
257 S: DataMut<Elem = A>,
258{
259 fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> {
260 A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?;
261 Ok(self.into_triangular(uplo))
262 }
263}
264
265/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix reference
266pub trait FactorizeC<S: Data> {
267 /// Computes the Cholesky decomposition of the Hermitian (or real
268 /// symmetric) positive definite matrix.
269 ///
270 /// If the argument is `UPLO::Upper`, then computes the decomposition `A =
271 /// U^H * U` using the upper triangular portion of `A` and returns the
272 /// factorization containing `U`. Otherwise, if the argument is
273 /// `UPLO::Lower`, computes the decomposition `A = L * L^H` using the lower
274 /// triangular portion of `A` and returns the factorization containing `L`.
275 fn factorizec(&self, uplo: UPLO) -> Result<CholeskyFactorized<S>>;
276}
277
278/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix
279pub trait FactorizeCInto<S: Data> {
280 /// Computes the Cholesky decomposition of the Hermitian (or real
281 /// symmetric) positive definite matrix.
282 ///
283 /// If the argument is `UPLO::Upper`, then computes the decomposition `A =
284 /// U^H * U` using the upper triangular portion of `A` and returns the
285 /// factorization containing `U`. Otherwise, if the argument is
286 /// `UPLO::Lower`, computes the decomposition `A = L * L^H` using the lower
287 /// triangular portion of `A` and returns the factorization containing `L`.
288 fn factorizec_into(self, uplo: UPLO) -> Result<CholeskyFactorized<S>>;
289}
290
291impl<A, S> FactorizeCInto<S> for ArrayBase<S, Ix2>
292where
293 A: Scalar + Lapack,
294 S: DataMut<Elem = A>,
295{
296 fn factorizec_into(self, uplo: UPLO) -> Result<CholeskyFactorized<S>> {
297 Ok(CholeskyFactorized {
298 factor: self.cholesky_into(uplo)?,
299 uplo,
300 })
301 }
302}
303
304impl<A, Si> FactorizeC<OwnedRepr<A>> for ArrayBase<Si, Ix2>
305where
306 A: Scalar + Lapack,
307 Si: Data<Elem = A>,
308{
309 fn factorizec(&self, uplo: UPLO) -> Result<CholeskyFactorized<OwnedRepr<A>>> {
310 Ok(CholeskyFactorized {
311 factor: self.cholesky(uplo)?,
312 uplo,
313 })
314 }
315}
316
317/// Solve systems of linear equations with Hermitian (or real symmetric)
318/// positive definite coefficient matrices
319pub trait SolveC<A: Scalar> {
320 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
321 /// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
322 /// the argument, and `x` is the successful result.
323 fn solvec<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
324 let mut b = replicate(b);
325 self.solvec_inplace(&mut b)?;
326 Ok(b)
327 }
328 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
329 /// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
330 /// the argument, and `x` is the successful result.
331 fn solvec_into<S: DataMut<Elem = A>>(
332 &self,
333 mut b: ArrayBase<S, Ix1>,
334 ) -> Result<ArrayBase<S, Ix1>> {
335 self.solvec_inplace(&mut b)?;
336 Ok(b)
337 }
338 /// Solves a system of linear equations `A * x = b` with Hermitian (or real
339 /// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
340 /// the argument, and `x` is the successful result. The value of `x` is
341 /// also assigned to the argument.
342 fn solvec_inplace<'a, S: DataMut<Elem = A>>(
343 &self,
344 b: &'a mut ArrayBase<S, Ix1>,
345 ) -> Result<&'a mut ArrayBase<S, Ix1>>;
346}
347
348impl<A, S> SolveC<A> for ArrayBase<S, Ix2>
349where
350 A: Scalar + Lapack,
351 S: Data<Elem = A>,
352{
353 fn solvec_inplace<'a, Sb>(
354 &self,
355 b: &'a mut ArrayBase<Sb, Ix1>,
356 ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
357 where
358 Sb: DataMut<Elem = A>,
359 {
360 self.factorizec(UPLO::Upper)?.solvec_inplace(b)
361 }
362}
363
364/// Inverse of Hermitian (or real symmetric) positive definite matrix ref
365pub trait InverseC {
366 type Output;
367 /// Computes the inverse of the Hermitian (or real symmetric) positive
368 /// definite matrix.
369 fn invc(&self) -> Result<Self::Output>;
370}
371
372/// Inverse of Hermitian (or real symmetric) positive definite matrix
373pub trait InverseCInto {
374 type Output;
375 /// Computes the inverse of the Hermitian (or real symmetric) positive
376 /// definite matrix.
377 fn invc_into(self) -> Result<Self::Output>;
378}
379
380impl<A, S> InverseC for ArrayBase<S, Ix2>
381where
382 A: Scalar + Lapack,
383 S: Data<Elem = A>,
384{
385 type Output = Array2<A>;
386
387 fn invc(&self) -> Result<Self::Output> {
388 self.factorizec(UPLO::Upper)?.invc_into()
389 }
390}
391
392impl<A, S> InverseCInto for ArrayBase<S, Ix2>
393where
394 A: Scalar + Lapack,
395 S: DataMut<Elem = A>,
396{
397 type Output = Self;
398
399 fn invc_into(self) -> Result<Self::Output> {
400 self.factorizec_into(UPLO::Upper)?.invc_into()
401 }
402}
403
/// Determinant of a Hermitian (or real symmetric) positive definite matrix,
/// taken by reference.
pub trait DeterminantC {
    /// Type of the returned determinant.
    type Output;

    /// Computes the determinant of the Hermitian (or real symmetric)
    /// positive definite matrix.
    fn detc(&self) -> Self::Output;

    /// Computes the natural logarithm of the determinant of the Hermitian
    /// (or real symmetric) positive definite matrix.
    ///
    /// Because it returns the logarithm rather than the determinant
    /// itself, this method is more robust than `.detc()` when the
    /// determinant is very small or very large.
    fn ln_detc(&self) -> Self::Output;
}
420
/// Determinant of a Hermitian (or real symmetric) positive definite matrix,
/// taken by value.
pub trait DeterminantCInto {
    /// Type of the returned determinant.
    type Output;

    /// Computes the determinant of the Hermitian (or real symmetric)
    /// positive definite matrix, consuming it.
    fn detc_into(self) -> Self::Output;

    /// Computes the natural logarithm of the determinant of the Hermitian
    /// (or real symmetric) positive definite matrix, consuming it.
    ///
    /// Because it returns the logarithm rather than the determinant
    /// itself, this method is more robust than `.detc_into()` when the
    /// determinant is very small or very large.
    fn ln_detc_into(self) -> Self::Output;
}
437
438impl<A, S> DeterminantC for ArrayBase<S, Ix2>
439where
440 A: Scalar + Lapack,
441 S: Data<Elem = A>,
442{
443 type Output = Result<<A as Scalar>::Real>;
444
445 fn detc(&self) -> Self::Output {
446 Ok(Float::exp(self.ln_detc()?))
447 }
448
449 fn ln_detc(&self) -> Self::Output {
450 Ok(self.factorizec(UPLO::Upper)?.ln_detc())
451 }
452}
453
454impl<A, S> DeterminantCInto for ArrayBase<S, Ix2>
455where
456 A: Scalar + Lapack,
457 S: DataMut<Elem = A>,
458{
459 type Output = Result<<A as Scalar>::Real>;
460
461 fn detc_into(self) -> Self::Output {
462 Ok(Float::exp(self.ln_detc_into()?))
463 }
464
465 fn ln_detc_into(self) -> Self::Output {
466 Ok(self.factorizec_into(UPLO::Upper)?.ln_detc_into())
467 }
468}