lax/
cholesky.rs

1//! Factorize positive-definite symmetric/Hermitian matrices using Cholesky algorithm
2
3use super::*;
4use crate::{error::*, layout::*};
5use cauchy::*;
6
7/// Compute Cholesky decomposition according to [UPLO]
8///
9/// LAPACK correspondance
10/// ----------------------
11///
12/// | f32    | f64    | c32    | c64    |
13/// |:-------|:-------|:-------|:-------|
14/// | spotrf | dpotrf | cpotrf | zpotrf |
15///
16pub trait CholeskyImpl: Scalar {
17    fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
18}
19
20macro_rules! impl_cholesky_ {
21    ($s:ty, $trf:path) => {
22        impl CholeskyImpl for $s {
23            fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
24                let (n, _) = l.size();
25                if matches!(l, MatrixLayout::C { .. }) {
26                    square_transpose(l, a);
27                }
28                let mut info = 0;
29                unsafe {
30                    $trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info);
31                }
32                info.as_lapack_result()?;
33                if matches!(l, MatrixLayout::C { .. }) {
34                    square_transpose(l, a);
35                }
36                Ok(())
37            }
38        }
39    };
40}
41impl_cholesky_!(c64, lapack_sys::zpotrf_);
42impl_cholesky_!(c32, lapack_sys::cpotrf_);
43impl_cholesky_!(f64, lapack_sys::dpotrf_);
44impl_cholesky_!(f32, lapack_sys::spotrf_);
45
46/// Compute inverse matrix using Cholesky factroization result
47///
48/// LAPACK correspondance
49/// ----------------------
50///
51/// | f32    | f64    | c32    | c64    |
52/// |:-------|:-------|:-------|:-------|
53/// | spotri | dpotri | cpotri | zpotri |
54///
55pub trait InvCholeskyImpl: Scalar {
56    fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
57}
58
59macro_rules! impl_inv_cholesky {
60    ($s:ty, $tri:path) => {
61        impl InvCholeskyImpl for $s {
62            fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
63                let (n, _) = l.size();
64                if matches!(l, MatrixLayout::C { .. }) {
65                    square_transpose(l, a);
66                }
67                let mut info = 0;
68                unsafe {
69                    $tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info);
70                }
71                info.as_lapack_result()?;
72                if matches!(l, MatrixLayout::C { .. }) {
73                    square_transpose(l, a);
74                }
75                Ok(())
76            }
77        }
78    };
79}
80impl_inv_cholesky!(c64, lapack_sys::zpotri_);
81impl_inv_cholesky!(c32, lapack_sys::cpotri_);
82impl_inv_cholesky!(f64, lapack_sys::dpotri_);
83impl_inv_cholesky!(f32, lapack_sys::spotri_);
84
85/// Solve linear equation using Cholesky factroization result
86///
87/// LAPACK correspondance
88/// ----------------------
89///
90/// | f32    | f64    | c32    | c64    |
91/// |:-------|:-------|:-------|:-------|
92/// | spotrs | dpotrs | cpotrs | zpotrs |
93///
94pub trait SolveCholeskyImpl: Scalar {
95    fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
96}
97
98macro_rules! impl_solve_cholesky {
99    ($s:ty, $trs:path) => {
100        impl SolveCholeskyImpl for $s {
101            fn solve_cholesky(
102                l: MatrixLayout,
103                mut uplo: UPLO,
104                a: &[Self],
105                b: &mut [Self],
106            ) -> Result<()> {
107                let (n, _) = l.size();
108                let nrhs = 1;
109                let mut info = 0;
110                if matches!(l, MatrixLayout::C { .. }) {
111                    uplo = uplo.t();
112                    for val in b.iter_mut() {
113                        *val = val.conj();
114                    }
115                }
116                unsafe {
117                    $trs(
118                        uplo.as_ptr(),
119                        &n,
120                        &nrhs,
121                        AsPtr::as_ptr(a),
122                        &l.lda(),
123                        AsPtr::as_mut_ptr(b),
124                        &n,
125                        &mut info,
126                    );
127                }
128                info.as_lapack_result()?;
129                if matches!(l, MatrixLayout::C { .. }) {
130                    for val in b.iter_mut() {
131                        *val = val.conj();
132                    }
133                }
134                Ok(())
135            }
136        }
137    };
138}
139impl_solve_cholesky!(c64, lapack_sys::zpotrs_);
140impl_solve_cholesky!(c32, lapack_sys::cpotrs_);
141impl_solve_cholesky!(f64, lapack_sys::dpotrs_);
142impl_solve_cholesky!(f32, lapack_sys::spotrs_);