1use super::*;
11use crate::{error::*, layout::MatrixLayout};
12use cauchy::*;
13use num_traits::{ToPrimitive, Zero};
14
15pub struct EighWork<T: Scalar> {
16 pub n: i32,
17 pub jobz: JobEv,
18 pub eigs: Vec<MaybeUninit<T::Real>>,
19 pub work: Vec<MaybeUninit<T>>,
20 pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
21}
22
23pub trait EighWorkImpl: Sized {
24 type Elem: Scalar;
25 fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self>;
26 fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem])
27 -> Result<&[<Self::Elem as Scalar>::Real]>;
28 fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Vec<<Self::Elem as Scalar>::Real>>;
29}
30
31macro_rules! impl_eigh_work_c {
32 ($c:ty, $ev:path) => {
33 impl EighWorkImpl for EighWork<$c> {
34 type Elem = $c;
35
36 fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
37 assert_eq!(layout.len(), layout.lda());
38 let n = layout.len();
39 let jobz = if calc_eigenvectors {
40 JobEv::All
41 } else {
42 JobEv::None
43 };
44 let mut eigs = vec_uninit(n as usize);
45 let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
46 let mut info = 0;
47 let mut work_size = [Self::Elem::zero()];
48 unsafe {
49 $ev(
50 jobz.as_ptr(),
51 UPLO::Upper.as_ptr(), &n,
53 std::ptr::null_mut(),
54 &n,
55 AsPtr::as_mut_ptr(&mut eigs),
56 AsPtr::as_mut_ptr(&mut work_size),
57 &(-1),
58 AsPtr::as_mut_ptr(&mut rwork),
59 &mut info,
60 );
61 }
62 info.as_lapack_result()?;
63 let lwork = work_size[0].to_usize().unwrap();
64 let work = vec_uninit(lwork);
65 Ok(EighWork {
66 n,
67 eigs,
68 jobz,
69 work,
70 rwork: Some(rwork),
71 })
72 }
73
74 fn calc(
75 &mut self,
76 uplo: UPLO,
77 a: &mut [Self::Elem],
78 ) -> Result<&[<Self::Elem as Scalar>::Real]> {
79 let lwork = self.work.len().to_i32().unwrap();
80 let mut info = 0;
81 unsafe {
82 $ev(
83 self.jobz.as_ptr(),
84 uplo.as_ptr(),
85 &self.n,
86 AsPtr::as_mut_ptr(a),
87 &self.n,
88 AsPtr::as_mut_ptr(&mut self.eigs),
89 AsPtr::as_mut_ptr(&mut self.work),
90 &lwork,
91 AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
92 &mut info,
93 );
94 }
95 info.as_lapack_result()?;
96 Ok(unsafe { self.eigs.slice_assume_init_ref() })
97 }
98
99 fn eval(
100 mut self,
101 uplo: UPLO,
102 a: &mut [Self::Elem],
103 ) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
104 let _eig = self.calc(uplo, a)?;
105 Ok(unsafe { self.eigs.assume_init() })
106 }
107 }
108 };
109}
110impl_eigh_work_c!(c64, lapack_sys::zheev_);
111impl_eigh_work_c!(c32, lapack_sys::cheev_);
112
113macro_rules! impl_eigh_work_r {
114 ($f:ty, $ev:path) => {
115 impl EighWorkImpl for EighWork<$f> {
116 type Elem = $f;
117
118 fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
119 assert_eq!(layout.len(), layout.lda());
120 let n = layout.len();
121 let jobz = if calc_eigenvectors {
122 JobEv::All
123 } else {
124 JobEv::None
125 };
126 let mut eigs = vec_uninit(n as usize);
127 let mut info = 0;
128 let mut work_size = [Self::Elem::zero()];
129 unsafe {
130 $ev(
131 jobz.as_ptr(),
132 UPLO::Upper.as_ptr(), &n,
134 std::ptr::null_mut(),
135 &n,
136 AsPtr::as_mut_ptr(&mut eigs),
137 AsPtr::as_mut_ptr(&mut work_size),
138 &(-1),
139 &mut info,
140 );
141 }
142 info.as_lapack_result()?;
143 let lwork = work_size[0].to_usize().unwrap();
144 let work = vec_uninit(lwork);
145 Ok(EighWork {
146 n,
147 eigs,
148 jobz,
149 work,
150 rwork: None,
151 })
152 }
153
154 fn calc(
155 &mut self,
156 uplo: UPLO,
157 a: &mut [Self::Elem],
158 ) -> Result<&[<Self::Elem as Scalar>::Real]> {
159 let lwork = self.work.len().to_i32().unwrap();
160 let mut info = 0;
161 unsafe {
162 $ev(
163 self.jobz.as_ptr(),
164 uplo.as_ptr(),
165 &self.n,
166 AsPtr::as_mut_ptr(a),
167 &self.n,
168 AsPtr::as_mut_ptr(&mut self.eigs),
169 AsPtr::as_mut_ptr(&mut self.work),
170 &lwork,
171 &mut info,
172 );
173 }
174 info.as_lapack_result()?;
175 Ok(unsafe { self.eigs.slice_assume_init_ref() })
176 }
177
178 fn eval(
179 mut self,
180 uplo: UPLO,
181 a: &mut [Self::Elem],
182 ) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
183 let _eig = self.calc(uplo, a)?;
184 Ok(unsafe { self.eigs.assume_init() })
185 }
186 }
187 };
188}
189impl_eigh_work_r!(f64, lapack_sys::dsyev_);
190impl_eigh_work_r!(f32, lapack_sys::ssyev_);