1use super::*;
4use crate::{error::*, layout::*};
5use cauchy::*;
6
/// Cholesky decomposition of a square matrix, computed in place.
///
/// Implemented in this file for each scalar type via LAPACK:
///
/// | f32    | f64    | c32    | c64    |
/// |:-------|:-------|:-------|:-------|
/// | spotrf | dpotrf | cpotrf | zpotrf |
pub trait CholeskyImpl: Scalar {
    /// Factor the square matrix stored in `a` (described by layout `l`) in place.
    ///
    /// `uplo` selects which triangle LAPACK reads and overwrites with the factor
    /// (per the LAPACK `?potrf` documentation, the other triangle is not referenced).
    ///
    /// # Errors
    ///
    /// Propagates the error derived from LAPACK's `info` output via `as_lapack_result`.
    fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
}
19
20macro_rules! impl_cholesky_ {
21 ($s:ty, $trf:path) => {
22 impl CholeskyImpl for $s {
23 fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
24 let (n, _) = l.size();
25 if matches!(l, MatrixLayout::C { .. }) {
26 square_transpose(l, a);
27 }
28 let mut info = 0;
29 unsafe {
30 $trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info);
31 }
32 info.as_lapack_result()?;
33 if matches!(l, MatrixLayout::C { .. }) {
34 square_transpose(l, a);
35 }
36 Ok(())
37 }
38 }
39 };
40}
41impl_cholesky_!(c64, lapack_sys::zpotrf_);
42impl_cholesky_!(c32, lapack_sys::cpotrf_);
43impl_cholesky_!(f64, lapack_sys::dpotrf_);
44impl_cholesky_!(f32, lapack_sys::spotrf_);
45
/// Inverse of a Hermitian (symmetric for real types) positive-definite matrix,
/// computed in place from its Cholesky factor.
///
/// Implemented in this file for each scalar type via LAPACK:
///
/// | f32    | f64    | c32    | c64    |
/// |:-------|:-------|:-------|:-------|
/// | spotri | dpotri | cpotri | zpotri |
pub trait InvCholeskyImpl: Scalar {
    /// Overwrite the Cholesky factor stored in `a` (layout `l`) with the inverse matrix.
    ///
    /// `a` is expected to hold the factor in the triangle selected by `uplo`, such as the
    /// output of `CholeskyImpl::cholesky` with the same `uplo`. Per the LAPACK `?potri`
    /// documentation, only that triangle of the inverse is written.
    ///
    /// # Errors
    ///
    /// Propagates the error derived from LAPACK's `info` output via `as_lapack_result`.
    fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
}
58
59macro_rules! impl_inv_cholesky {
60 ($s:ty, $tri:path) => {
61 impl InvCholeskyImpl for $s {
62 fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
63 let (n, _) = l.size();
64 if matches!(l, MatrixLayout::C { .. }) {
65 square_transpose(l, a);
66 }
67 let mut info = 0;
68 unsafe {
69 $tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info);
70 }
71 info.as_lapack_result()?;
72 if matches!(l, MatrixLayout::C { .. }) {
73 square_transpose(l, a);
74 }
75 Ok(())
76 }
77 }
78 };
79}
80impl_inv_cholesky!(c64, lapack_sys::zpotri_);
81impl_inv_cholesky!(c32, lapack_sys::cpotri_);
82impl_inv_cholesky!(f64, lapack_sys::dpotri_);
83impl_inv_cholesky!(f32, lapack_sys::spotri_);
84
/// Solve a linear system `A x = b` using a precomputed Cholesky factor of `A`.
///
/// Implemented in this file for each scalar type via LAPACK:
///
/// | f32    | f64    | c32    | c64    |
/// |:-------|:-------|:-------|:-------|
/// | spotrs | dpotrs | cpotrs | zpotrs |
pub trait SolveCholeskyImpl: Scalar {
    /// Solve for a single right-hand side `b`, which is overwritten with the solution.
    ///
    /// `a` (layout `l`) holds the Cholesky factor in the triangle selected by `uplo`,
    /// such as the output of `CholeskyImpl::cholesky` with the same `uplo`; it is not modified.
    ///
    /// # Errors
    ///
    /// Propagates the error derived from LAPACK's `info` output via `as_lapack_result`.
    fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
}
97
98macro_rules! impl_solve_cholesky {
99 ($s:ty, $trs:path) => {
100 impl SolveCholeskyImpl for $s {
101 fn solve_cholesky(
102 l: MatrixLayout,
103 mut uplo: UPLO,
104 a: &[Self],
105 b: &mut [Self],
106 ) -> Result<()> {
107 let (n, _) = l.size();
108 let nrhs = 1;
109 let mut info = 0;
110 if matches!(l, MatrixLayout::C { .. }) {
111 uplo = uplo.t();
112 for val in b.iter_mut() {
113 *val = val.conj();
114 }
115 }
116 unsafe {
117 $trs(
118 uplo.as_ptr(),
119 &n,
120 &nrhs,
121 AsPtr::as_ptr(a),
122 &l.lda(),
123 AsPtr::as_mut_ptr(b),
124 &n,
125 &mut info,
126 );
127 }
128 info.as_lapack_result()?;
129 if matches!(l, MatrixLayout::C { .. }) {
130 for val in b.iter_mut() {
131 *val = val.conj();
132 }
133 }
134 Ok(())
135 }
136 }
137 };
138}
139impl_solve_cholesky!(c64, lapack_sys::zpotrs_);
140impl_solve_cholesky!(c32, lapack_sys::cpotrs_);
141impl_solve_cholesky!(f64, lapack_sys::dpotrs_);
142impl_solve_cholesky!(f32, lapack_sys::spotrs_);