1use super::*;
12use crate::{error::*, layout::MatrixLayout};
13use cauchy::*;
14use num_traits::{ToPrimitive, Zero};
15
16pub struct EighGeneralizedWork<T: Scalar> {
17 pub n: i32,
18 pub jobz: JobEv,
19 pub eigs: Vec<MaybeUninit<T::Real>>,
20 pub work: Vec<MaybeUninit<T>>,
21 pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
22}
23
24pub trait EighGeneralizedWorkImpl: Sized {
25 type Elem: Scalar;
26 fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self>;
27 fn calc(
28 &mut self,
29 uplo: UPLO,
30 a: &mut [Self::Elem],
31 b: &mut [Self::Elem],
32 ) -> Result<&[<Self::Elem as Scalar>::Real]>;
33 fn eval(
34 self,
35 uplo: UPLO,
36 a: &mut [Self::Elem],
37 b: &mut [Self::Elem],
38 ) -> Result<Vec<<Self::Elem as Scalar>::Real>>;
39}
40
41macro_rules! impl_eigh_generalized_work_c {
42 ($c:ty, $gv:path) => {
43 impl EighGeneralizedWorkImpl for EighGeneralizedWork<$c> {
44 type Elem = $c;
45
46 fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
47 assert_eq!(layout.len(), layout.lda());
48 let n = layout.len();
49 let jobz = if calc_eigenvectors {
50 JobEv::All
51 } else {
52 JobEv::None
53 };
54 let mut eigs = vec_uninit(n as usize);
55 let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
56 let mut info = 0;
57 let mut work_size = [Self::Elem::zero()];
58 unsafe {
59 $gv(
60 &1, jobz.as_ptr(),
62 UPLO::Upper.as_ptr(), &n,
64 std::ptr::null_mut(),
65 &n,
66 std::ptr::null_mut(),
67 &n,
68 AsPtr::as_mut_ptr(&mut eigs),
69 AsPtr::as_mut_ptr(&mut work_size),
70 &(-1),
71 AsPtr::as_mut_ptr(&mut rwork),
72 &mut info,
73 );
74 }
75 info.as_lapack_result()?;
76 let lwork = work_size[0].to_usize().unwrap();
77 let work = vec_uninit(lwork);
78 Ok(EighGeneralizedWork {
79 n,
80 eigs,
81 jobz,
82 work,
83 rwork: Some(rwork),
84 })
85 }
86
87 fn calc(
88 &mut self,
89 uplo: UPLO,
90 a: &mut [Self::Elem],
91 b: &mut [Self::Elem],
92 ) -> Result<&[<Self::Elem as Scalar>::Real]> {
93 let lwork = self.work.len().to_i32().unwrap();
94 let mut info = 0;
95 unsafe {
96 $gv(
97 &1, self.jobz.as_ptr(),
99 uplo.as_ptr(),
100 &self.n,
101 AsPtr::as_mut_ptr(a),
102 &self.n,
103 AsPtr::as_mut_ptr(b),
104 &self.n,
105 AsPtr::as_mut_ptr(&mut self.eigs),
106 AsPtr::as_mut_ptr(&mut self.work),
107 &lwork,
108 AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
109 &mut info,
110 );
111 }
112 info.as_lapack_result()?;
113 Ok(unsafe { self.eigs.slice_assume_init_ref() })
114 }
115
116 fn eval(
117 mut self,
118 uplo: UPLO,
119 a: &mut [Self::Elem],
120 b: &mut [Self::Elem],
121 ) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
122 let _eig = self.calc(uplo, a, b)?;
123 Ok(unsafe { self.eigs.assume_init() })
124 }
125 }
126 };
127}
128impl_eigh_generalized_work_c!(c64, lapack_sys::zhegv_);
129impl_eigh_generalized_work_c!(c32, lapack_sys::chegv_);
130
131macro_rules! impl_eigh_generalized_work_r {
132 ($f:ty, $gv:path) => {
133 impl EighGeneralizedWorkImpl for EighGeneralizedWork<$f> {
134 type Elem = $f;
135
136 fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
137 assert_eq!(layout.len(), layout.lda());
138 let n = layout.len();
139 let jobz = if calc_eigenvectors {
140 JobEv::All
141 } else {
142 JobEv::None
143 };
144 let mut eigs = vec_uninit(n as usize);
145 let mut info = 0;
146 let mut work_size = [Self::Elem::zero()];
147 unsafe {
148 $gv(
149 &1, jobz.as_ptr(),
151 UPLO::Upper.as_ptr(), &n,
153 std::ptr::null_mut(),
154 &n,
155 std::ptr::null_mut(),
156 &n,
157 AsPtr::as_mut_ptr(&mut eigs),
158 AsPtr::as_mut_ptr(&mut work_size),
159 &(-1),
160 &mut info,
161 );
162 }
163 info.as_lapack_result()?;
164 let lwork = work_size[0].to_usize().unwrap();
165 let work = vec_uninit(lwork);
166 Ok(EighGeneralizedWork {
167 n,
168 eigs,
169 jobz,
170 work,
171 rwork: None,
172 })
173 }
174
175 fn calc(
176 &mut self,
177 uplo: UPLO,
178 a: &mut [Self::Elem],
179 b: &mut [Self::Elem],
180 ) -> Result<&[<Self::Elem as Scalar>::Real]> {
181 let lwork = self.work.len().to_i32().unwrap();
182 let mut info = 0;
183 unsafe {
184 $gv(
185 &1, self.jobz.as_ptr(),
187 uplo.as_ptr(),
188 &self.n,
189 AsPtr::as_mut_ptr(a),
190 &self.n,
191 AsPtr::as_mut_ptr(b),
192 &self.n,
193 AsPtr::as_mut_ptr(&mut self.eigs),
194 AsPtr::as_mut_ptr(&mut self.work),
195 &lwork,
196 &mut info,
197 );
198 }
199 info.as_lapack_result()?;
200 Ok(unsafe { self.eigs.slice_assume_init_ref() })
201 }
202
203 fn eval(
204 mut self,
205 uplo: UPLO,
206 a: &mut [Self::Elem],
207 b: &mut [Self::Elem],
208 ) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
209 let _eig = self.calc(uplo, a, b)?;
210 Ok(unsafe { self.eigs.assume_init() })
211 }
212 }
213 };
214}
215impl_eigh_generalized_work_r!(f64, lapack_sys::dsygv_);
216impl_eigh_generalized_work_r!(f32, lapack_sys::ssygv_);