lax/
eigh.rs

//! Eigenvalue problem for symmetric/Hermitian matrices
2//!
//! LAPACK correspondence
4//! ----------------------
5//!
6//! | f32   | f64   | c32   | c64   |
7//! |:------|:------|:------|:------|
8//! | ssyev | dsyev | cheev | zheev |
9
10use super::*;
11use crate::{error::*, layout::MatrixLayout};
12use cauchy::*;
13use num_traits::{ToPrimitive, Zero};
14
15pub struct EighWork<T: Scalar> {
16    pub n: i32,
17    pub jobz: JobEv,
18    pub eigs: Vec<MaybeUninit<T::Real>>,
19    pub work: Vec<MaybeUninit<T>>,
20    pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
21}
22
23pub trait EighWorkImpl: Sized {
24    type Elem: Scalar;
25    fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self>;
26    fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem])
27        -> Result<&[<Self::Elem as Scalar>::Real]>;
28    fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Vec<<Self::Elem as Scalar>::Real>>;
29}
30
31macro_rules! impl_eigh_work_c {
32    ($c:ty, $ev:path) => {
33        impl EighWorkImpl for EighWork<$c> {
34            type Elem = $c;
35
36            fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
37                assert_eq!(layout.len(), layout.lda());
38                let n = layout.len();
39                let jobz = if calc_eigenvectors {
40                    JobEv::All
41                } else {
42                    JobEv::None
43                };
44                let mut eigs = vec_uninit(n as usize);
45                let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
46                let mut info = 0;
47                let mut work_size = [Self::Elem::zero()];
48                unsafe {
49                    $ev(
50                        jobz.as_ptr(),
51                        UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
52                        &n,
53                        std::ptr::null_mut(),
54                        &n,
55                        AsPtr::as_mut_ptr(&mut eigs),
56                        AsPtr::as_mut_ptr(&mut work_size),
57                        &(-1),
58                        AsPtr::as_mut_ptr(&mut rwork),
59                        &mut info,
60                    );
61                }
62                info.as_lapack_result()?;
63                let lwork = work_size[0].to_usize().unwrap();
64                let work = vec_uninit(lwork);
65                Ok(EighWork {
66                    n,
67                    eigs,
68                    jobz,
69                    work,
70                    rwork: Some(rwork),
71                })
72            }
73
74            fn calc(
75                &mut self,
76                uplo: UPLO,
77                a: &mut [Self::Elem],
78            ) -> Result<&[<Self::Elem as Scalar>::Real]> {
79                let lwork = self.work.len().to_i32().unwrap();
80                let mut info = 0;
81                unsafe {
82                    $ev(
83                        self.jobz.as_ptr(),
84                        uplo.as_ptr(),
85                        &self.n,
86                        AsPtr::as_mut_ptr(a),
87                        &self.n,
88                        AsPtr::as_mut_ptr(&mut self.eigs),
89                        AsPtr::as_mut_ptr(&mut self.work),
90                        &lwork,
91                        AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
92                        &mut info,
93                    );
94                }
95                info.as_lapack_result()?;
96                Ok(unsafe { self.eigs.slice_assume_init_ref() })
97            }
98
99            fn eval(
100                mut self,
101                uplo: UPLO,
102                a: &mut [Self::Elem],
103            ) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
104                let _eig = self.calc(uplo, a)?;
105                Ok(unsafe { self.eigs.assume_init() })
106            }
107        }
108    };
109}
110impl_eigh_work_c!(c64, lapack_sys::zheev_);
111impl_eigh_work_c!(c32, lapack_sys::cheev_);
112
113macro_rules! impl_eigh_work_r {
114    ($f:ty, $ev:path) => {
115        impl EighWorkImpl for EighWork<$f> {
116            type Elem = $f;
117
118            fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
119                assert_eq!(layout.len(), layout.lda());
120                let n = layout.len();
121                let jobz = if calc_eigenvectors {
122                    JobEv::All
123                } else {
124                    JobEv::None
125                };
126                let mut eigs = vec_uninit(n as usize);
127                let mut info = 0;
128                let mut work_size = [Self::Elem::zero()];
129                unsafe {
130                    $ev(
131                        jobz.as_ptr(),
132                        UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
133                        &n,
134                        std::ptr::null_mut(),
135                        &n,
136                        AsPtr::as_mut_ptr(&mut eigs),
137                        AsPtr::as_mut_ptr(&mut work_size),
138                        &(-1),
139                        &mut info,
140                    );
141                }
142                info.as_lapack_result()?;
143                let lwork = work_size[0].to_usize().unwrap();
144                let work = vec_uninit(lwork);
145                Ok(EighWork {
146                    n,
147                    eigs,
148                    jobz,
149                    work,
150                    rwork: None,
151                })
152            }
153
154            fn calc(
155                &mut self,
156                uplo: UPLO,
157                a: &mut [Self::Elem],
158            ) -> Result<&[<Self::Elem as Scalar>::Real]> {
159                let lwork = self.work.len().to_i32().unwrap();
160                let mut info = 0;
161                unsafe {
162                    $ev(
163                        self.jobz.as_ptr(),
164                        uplo.as_ptr(),
165                        &self.n,
166                        AsPtr::as_mut_ptr(a),
167                        &self.n,
168                        AsPtr::as_mut_ptr(&mut self.eigs),
169                        AsPtr::as_mut_ptr(&mut self.work),
170                        &lwork,
171                        &mut info,
172                    );
173                }
174                info.as_lapack_result()?;
175                Ok(unsafe { self.eigs.slice_assume_init_ref() })
176            }
177
178            fn eval(
179                mut self,
180                uplo: UPLO,
181                a: &mut [Self::Elem],
182            ) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
183                let _eig = self.calc(uplo, a)?;
184                Ok(unsafe { self.eigs.assume_init() })
185            }
186        }
187    };
188}
189impl_eigh_work_r!(f64, lapack_sys::dsyev_);
190impl_eigh_work_r!(f32, lapack_sys::ssyev_);