lax/
eigh_generalized.rs

1//! Generalized eigenvalue problem for symmetric/Hermitian matrices
2//!
//! LAPACK correspondence
4//! ----------------------
5//!
6//! | f32   | f64   | c32   | c64   |
7//! |:------|:------|:------|:------|
8//! | ssygv | dsygv | chegv | zhegv |
9//!
10
11use super::*;
12use crate::{error::*, layout::MatrixLayout};
13use cauchy::*;
14use num_traits::{ToPrimitive, Zero};
15
/// Reusable working memory for the generalized symmetric/Hermitian
/// eigenvalue problem `A x = lambda B x`, sized once by a LAPACK workspace query.
pub struct EighGeneralizedWork<T: Scalar> {
    /// Order of the square matrices `A` and `B` (taken from `MatrixLayout::len`).
    pub n: i32,
    /// Whether eigenvectors are computed (`JobEv::All`) or only eigenvalues (`JobEv::None`).
    pub jobz: JobEv,
    /// Output buffer for the `n` eigenvalues; filled by a successful `calc`.
    pub eigs: Vec<MaybeUninit<T::Real>>,
    /// LAPACK workspace, sized from the optimal length reported by the workspace query.
    pub work: Vec<MaybeUninit<T>>,
    /// Real workspace used by the complex (`?hegv`) routines; `None` for real types.
    pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
}
23
/// Solver interface for the generalized eigenvalue problem `A x = lambda B x`
/// with symmetric/Hermitian `A` and `B`, backed by LAPACK `?sygv`/`?hegv`.
pub trait EighGeneralizedWorkImpl: Sized {
    /// Scalar type of the matrix elements.
    type Elem: Scalar;
    /// Allocates working memory for matrices described by `layout`.
    ///
    /// Implementations run a LAPACK workspace query (`lwork = -1`) to size the buffers.
    /// They panic if `layout.len() != layout.lda()`.
    fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self>;
    /// Solves the problem in place, reusing this working memory.
    ///
    /// `a` and `b` are passed to LAPACK as in/out buffers and may be overwritten;
    /// `uplo` selects which triangle of them is referenced. Returns the eigenvalues
    /// as a slice borrowed from the internal buffer.
    fn calc(
        &mut self,
        uplo: UPLO,
        a: &mut [Self::Elem],
        b: &mut [Self::Elem],
    ) -> Result<&[<Self::Elem as Scalar>::Real]>;
    /// Like `calc`, but consumes the working memory and returns the eigenvalues as an owned `Vec`.
    fn eval(
        self,
        uplo: UPLO,
        a: &mut [Self::Elem],
        b: &mut [Self::Elem],
    ) -> Result<Vec<<Self::Elem as Scalar>::Real>>;
}
40
41macro_rules! impl_eigh_generalized_work_c {
42    ($c:ty, $gv:path) => {
43        impl EighGeneralizedWorkImpl for EighGeneralizedWork<$c> {
44            type Elem = $c;
45
46            fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
47                assert_eq!(layout.len(), layout.lda());
48                let n = layout.len();
49                let jobz = if calc_eigenvectors {
50                    JobEv::All
51                } else {
52                    JobEv::None
53                };
54                let mut eigs = vec_uninit(n as usize);
55                let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
56                let mut info = 0;
57                let mut work_size = [Self::Elem::zero()];
58                unsafe {
59                    $gv(
60                        &1, // ITYPE A*x = (lambda)*B*x
61                        jobz.as_ptr(),
62                        UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
63                        &n,
64                        std::ptr::null_mut(),
65                        &n,
66                        std::ptr::null_mut(),
67                        &n,
68                        AsPtr::as_mut_ptr(&mut eigs),
69                        AsPtr::as_mut_ptr(&mut work_size),
70                        &(-1),
71                        AsPtr::as_mut_ptr(&mut rwork),
72                        &mut info,
73                    );
74                }
75                info.as_lapack_result()?;
76                let lwork = work_size[0].to_usize().unwrap();
77                let work = vec_uninit(lwork);
78                Ok(EighGeneralizedWork {
79                    n,
80                    eigs,
81                    jobz,
82                    work,
83                    rwork: Some(rwork),
84                })
85            }
86
87            fn calc(
88                &mut self,
89                uplo: UPLO,
90                a: &mut [Self::Elem],
91                b: &mut [Self::Elem],
92            ) -> Result<&[<Self::Elem as Scalar>::Real]> {
93                let lwork = self.work.len().to_i32().unwrap();
94                let mut info = 0;
95                unsafe {
96                    $gv(
97                        &1, // ITYPE A*x = (lambda)*B*x
98                        self.jobz.as_ptr(),
99                        uplo.as_ptr(),
100                        &self.n,
101                        AsPtr::as_mut_ptr(a),
102                        &self.n,
103                        AsPtr::as_mut_ptr(b),
104                        &self.n,
105                        AsPtr::as_mut_ptr(&mut self.eigs),
106                        AsPtr::as_mut_ptr(&mut self.work),
107                        &lwork,
108                        AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
109                        &mut info,
110                    );
111                }
112                info.as_lapack_result()?;
113                Ok(unsafe { self.eigs.slice_assume_init_ref() })
114            }
115
116            fn eval(
117                mut self,
118                uplo: UPLO,
119                a: &mut [Self::Elem],
120                b: &mut [Self::Elem],
121            ) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
122                let _eig = self.calc(uplo, a, b)?;
123                Ok(unsafe { self.eigs.assume_init() })
124            }
125        }
126    };
127}
128impl_eigh_generalized_work_c!(c64, lapack_sys::zhegv_);
129impl_eigh_generalized_work_c!(c32, lapack_sys::chegv_);
130
131macro_rules! impl_eigh_generalized_work_r {
132    ($f:ty, $gv:path) => {
133        impl EighGeneralizedWorkImpl for EighGeneralizedWork<$f> {
134            type Elem = $f;
135
136            fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
137                assert_eq!(layout.len(), layout.lda());
138                let n = layout.len();
139                let jobz = if calc_eigenvectors {
140                    JobEv::All
141                } else {
142                    JobEv::None
143                };
144                let mut eigs = vec_uninit(n as usize);
145                let mut info = 0;
146                let mut work_size = [Self::Elem::zero()];
147                unsafe {
148                    $gv(
149                        &1, // ITYPE A*x = (lambda)*B*x
150                        jobz.as_ptr(),
151                        UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
152                        &n,
153                        std::ptr::null_mut(),
154                        &n,
155                        std::ptr::null_mut(),
156                        &n,
157                        AsPtr::as_mut_ptr(&mut eigs),
158                        AsPtr::as_mut_ptr(&mut work_size),
159                        &(-1),
160                        &mut info,
161                    );
162                }
163                info.as_lapack_result()?;
164                let lwork = work_size[0].to_usize().unwrap();
165                let work = vec_uninit(lwork);
166                Ok(EighGeneralizedWork {
167                    n,
168                    eigs,
169                    jobz,
170                    work,
171                    rwork: None,
172                })
173            }
174
175            fn calc(
176                &mut self,
177                uplo: UPLO,
178                a: &mut [Self::Elem],
179                b: &mut [Self::Elem],
180            ) -> Result<&[<Self::Elem as Scalar>::Real]> {
181                let lwork = self.work.len().to_i32().unwrap();
182                let mut info = 0;
183                unsafe {
184                    $gv(
185                        &1, // ITYPE A*x = (lambda)*B*x
186                        self.jobz.as_ptr(),
187                        uplo.as_ptr(),
188                        &self.n,
189                        AsPtr::as_mut_ptr(a),
190                        &self.n,
191                        AsPtr::as_mut_ptr(b),
192                        &self.n,
193                        AsPtr::as_mut_ptr(&mut self.eigs),
194                        AsPtr::as_mut_ptr(&mut self.work),
195                        &lwork,
196                        &mut info,
197                    );
198                }
199                info.as_lapack_result()?;
200                Ok(unsafe { self.eigs.slice_assume_init_ref() })
201            }
202
203            fn eval(
204                mut self,
205                uplo: UPLO,
206                a: &mut [Self::Elem],
207                b: &mut [Self::Elem],
208            ) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
209                let _eig = self.calc(uplo, a, b)?;
210                Ok(unsafe { self.eigs.assume_init() })
211            }
212        }
213    };
214}
215impl_eigh_generalized_work_r!(f64, lapack_sys::dsygv_);
216impl_eigh_generalized_work_r!(f32, lapack_sys::ssygv_);