lax/
solveh.rs

//! Factorize a symmetric/Hermitian matrix using the [Bunch-Kaufman diagonal pivoting method][BK], and use the factorization to invert the matrix or to solve linear equations
2//!
3//! [BK]: https://doi.org/10.2307/2005787
4//!
5
6use crate::{error::*, layout::MatrixLayout, *};
7use cauchy::*;
8use num_traits::{ToPrimitive, Zero};
9
10pub struct BkWork<T: Scalar> {
11    pub layout: MatrixLayout,
12    pub work: Vec<MaybeUninit<T>>,
13    pub ipiv: Vec<MaybeUninit<i32>>,
14}
15
/// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method
///
/// LAPACK correspondence
/// ----------------------
///
/// | f32    | f64    | c32    | c64    |
/// |:-------|:-------|:-------|:-------|
/// | ssytrf | dsytrf | chetrf | zhetrf |
///
pub trait BkWorkImpl: Sized {
    /// Scalar type of the matrix elements.
    type Elem: Scalar;
    /// Prepare the working memory for a matrix with layout `l`.
    fn new(l: MatrixLayout) -> Result<Self>;
    /// Factorize `a` in place and return the pivot indices, borrowed from `self`.
    fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]>;
    /// Factorize `a` in place, consuming `self`, and return the pivot indices as an owned [`Pivot`].
    fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Pivot>;
}
31
/// Implement [`BkWorkImpl`] for `BkWork<$s>` on top of the LAPACK routine `$trf`.
macro_rules! impl_bk_work {
    ($s:ty, $trf:path) => {
        impl BkWorkImpl for BkWork<$s> {
            type Elem = $s;

            /// Allocate the pivot buffer (one entry per row) and a workspace whose
            /// size is obtained from the routine's own workspace query.
            fn new(layout: MatrixLayout) -> Result<Self> {
                let (n, _) = layout.size();
                let ipiv = vec_uninit(n as usize);
                let mut info = 0;
                let mut work_size = [Self::Elem::zero()];
                // SAFETY: `lwork = -1` requests a workspace-size query. LAPACK documents
                // that in this mode the routine only stores the optimal size in the first
                // entry of `work` (plus `info`), so the null matrix and pivot pointers are
                // not dereferenced; every other pointer refers to a live local.
                unsafe {
                    $trf(
                        UPLO::Upper.as_ptr(),
                        &n,
                        std::ptr::null_mut(),
                        &layout.lda(),
                        std::ptr::null_mut(),
                        AsPtr::as_mut_ptr(&mut work_size),
                        &(-1),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                let lwork = work_size[0].to_usize().unwrap();
                let work = vec_uninit(lwork);
                Ok(BkWork { layout, work, ipiv })
            }

            /// Factorize `a` in place and return the pivot indices stored in `self`.
            ///
            /// # Panics
            ///
            /// Panics if `a` or the pivot buffer is too small for the layout this
            /// working memory was created with, or if the workspace length does not
            /// fit in an `i32`.
            fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]> {
                let (n, _) = self.layout.size();
                let lwork = self.work.len().to_i32().unwrap();
                // An empty workspace is treated as "nothing to factorize"
                // (presumably the zero-sized matrix case), so the routine is skipped.
                if lwork == 0 {
                    return Ok(&[]);
                }
                let lda = self.layout.lda();

                // The routine only receives raw pointers, so undersized buffers must be
                // rejected here; otherwise it would read or write out of bounds.
                let order = n.max(0) as usize;
                let a_len_needed = match order {
                    0 => 0,
                    _ => (order - 1)
                        .saturating_mul(lda.max(0) as usize)
                        .saturating_add(order),
                };
                assert!(
                    a.len() >= a_len_needed,
                    "BkWork::calc: matrix buffer has {} elements but the layout needs at least {}",
                    a.len(),
                    a_len_needed
                );
                assert!(
                    self.ipiv.len() >= order,
                    "BkWork::calc: pivot buffer has {} entries but {} are needed",
                    self.ipiv.len(),
                    order
                );

                let mut info = 0;
                // SAFETY: `a` and `ipiv` were checked above to cover an `n`-by-`n` matrix
                // with leading dimension `lda` and `n` pivots; `work` holds exactly
                // `lwork` elements; the scalar arguments point to live locals.
                unsafe {
                    $trf(
                        uplo.as_ptr(),
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &lda,
                        AsPtr::as_mut_ptr(&mut self.ipiv),
                        AsPtr::as_mut_ptr(&mut self.work),
                        &lwork,
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                // SAFETY: the routine reported success, so it has written the pivots.
                Ok(unsafe { self.ipiv.slice_assume_init_ref() })
            }

            /// Factorize `a` in place and return the pivot indices by value.
            fn eval(mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Pivot> {
                self.calc(uplo, a)?;
                // SAFETY: `calc` succeeded, so the pivot buffer has been filled in.
                Ok(unsafe { self.ipiv.assume_init() })
            }
        }
    };
}
90impl_bk_work!(c64, lapack_sys::zhetrf_);
91impl_bk_work!(c32, lapack_sys::chetrf_);
92impl_bk_work!(f64, lapack_sys::dsytrf_);
93impl_bk_work!(f32, lapack_sys::ssytrf_);
94
95pub struct InvhWork<T: Scalar> {
96    pub layout: MatrixLayout,
97    pub work: Vec<MaybeUninit<T>>,
98}
99
/// Compute inverse matrix of symmetric/Hermitian matrix
///
/// LAPACK correspondence
/// ----------------------
///
/// | f32    | f64    | c32    | c64    |
/// |:-------|:-------|:-------|:-------|
/// | ssytri | dsytri | chetri | zhetri |
///
pub trait InvhWorkImpl: Sized {
    /// Element type of the matrix.
    type Elem;
    /// Prepare the working memory for a matrix with the given `layout`.
    fn new(layout: MatrixLayout) -> Result<Self>;
    /// Pass the matrix buffer `a` and the pivot indices `ipiv` to the LAPACK
    /// inversion routine, which works on `a` in place.
    fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()>;
}
114
/// Implement [`InvhWorkImpl`] for `InvhWork<$s>` on top of the LAPACK routine `$tri`.
macro_rules! impl_invh_work {
    ($s:ty, $tri:path) => {
        impl InvhWorkImpl for InvhWork<$s> {
            type Elem = $s;

            /// Allocate a workspace with one element per matrix row.
            fn new(layout: MatrixLayout) -> Result<Self> {
                let (n, _) = layout.size();
                let work = vec_uninit(n as usize);
                Ok(InvhWork { layout, work })
            }

            /// Invert the factorized matrix stored in `a`, in place, using the pivot
            /// indices `ipiv` produced by the factorization.
            ///
            /// # Panics
            ///
            /// Panics if `a`, `ipiv` or the workspace is too small for the layout this
            /// working memory was created with.
            fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
                let (n, _) = self.layout.size();
                let lda = self.layout.lda();

                // The routine only receives raw pointers, so undersized buffers must be
                // rejected here; otherwise it would read or write out of bounds.
                let order = n.max(0) as usize;
                let a_len_needed = match order {
                    0 => 0,
                    _ => (order - 1)
                        .saturating_mul(lda.max(0) as usize)
                        .saturating_add(order),
                };
                assert!(
                    a.len() >= a_len_needed,
                    "InvhWork::calc: matrix buffer has {} elements but the layout needs at least {}",
                    a.len(),
                    a_len_needed
                );
                assert!(
                    ipiv.len() >= order,
                    "InvhWork::calc: pivot array has {} entries but {} are needed",
                    ipiv.len(),
                    order
                );
                assert!(
                    self.work.len() >= order,
                    "InvhWork::calc: workspace has {} elements but {} are needed",
                    self.work.len(),
                    order
                );

                let mut info = 0;
                // SAFETY: `a`, `ipiv` and `work` were checked above to cover an `n`-by-`n`
                // matrix with leading dimension `lda`, `n` pivots and `n` workspace
                // elements; the scalar arguments point to live locals.
                unsafe {
                    $tri(
                        uplo.as_ptr(),
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &lda,
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(&mut self.work),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
}
146impl_invh_work!(c64, lapack_sys::zhetri_);
147impl_invh_work!(c32, lapack_sys::chetri_);
148impl_invh_work!(f64, lapack_sys::dsytri_);
149impl_invh_work!(f32, lapack_sys::ssytri_);
150
/// Solve symmetric/Hermitian linear equation
///
/// LAPACK correspondence
/// ----------------------
///
/// | f32    | f64    | c32    | c64    |
/// |:-------|:-------|:-------|:-------|
/// | ssytrs | dsytrs | chetrs | zhetrs |
///
pub trait SolvehImpl: Scalar {
    /// Pass the matrix buffer `a` (layout `l`), the pivot indices `ipiv` and the
    /// right-hand side `b` to the LAPACK solver; `b` is the only buffer handed
    /// over mutably.
    fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
}
163
/// Implement [`SolvehImpl`] for the scalar type `$s` on top of the LAPACK routine `$trs`.
macro_rules! impl_solveh_ {
    ($s:ty, $trs:path) => {
        impl SolvehImpl for $s {
            /// Solve the linear equation for a single right-hand side `b`, in place,
            /// using the factorized matrix `a` and its pivot indices `ipiv`.
            ///
            /// # Panics
            ///
            /// Panics if `a`, `ipiv` or `b` is too small for the layout `l`.
            fn solveh(
                l: MatrixLayout,
                uplo: UPLO,
                a: &[Self],
                ipiv: &Pivot,
                b: &mut [Self],
            ) -> Result<()> {
                let (n, _) = l.size();
                let lda = l.lda();

                // The routine only receives raw pointers, so undersized buffers must be
                // rejected here; otherwise it would read or write out of bounds.
                let order = n.max(0) as usize;
                let a_len_needed = match order {
                    0 => 0,
                    _ => (order - 1)
                        .saturating_mul(lda.max(0) as usize)
                        .saturating_add(order),
                };
                assert!(
                    a.len() >= a_len_needed,
                    "solveh: matrix buffer has {} elements but the layout needs at least {}",
                    a.len(),
                    a_len_needed
                );
                assert!(
                    ipiv.len() >= order,
                    "solveh: pivot array has {} entries but {} are needed",
                    ipiv.len(),
                    order
                );
                assert!(
                    b.len() >= order,
                    "solveh: right-hand side has {} elements but {} are needed",
                    b.len(),
                    order
                );

                let mut info = 0;
                // SAFETY: `a`, `ipiv` and `b` were checked above to cover an `n`-by-`n`
                // matrix with leading dimension `lda`, `n` pivots and one right-hand
                // side of length `n`; the scalar arguments point to live locals.
                unsafe {
                    $trs(
                        uplo.as_ptr(),
                        &n,
                        &1, // number of right-hand sides
                        AsPtr::as_ptr(a),
                        &lda,
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(b),
                        &n, // leading dimension of `b`
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
}
195
196impl_solveh_!(c64, lapack_sys::zhetrs_);
197impl_solveh_!(c32, lapack_sys::chetrs_);
198impl_solveh_!(f64, lapack_sys::dsytrs_);
199impl_solveh_!(f32, lapack_sys::ssytrs_);