ndarray_linalg/lobpcg/
eig.rs1use super::lobpcg::{lobpcg, LobpcgResult, Order};
2use crate::{generate, Scalar};
3use lax::Lapack;
4
5use ndarray::prelude::*;
8use ndarray::stack;
9use ndarray::ScalarOperand;
10use num_traits::{Float, NumCast};
11
12pub struct TruncatedEig<A: Scalar> {
19 order: Order,
20 problem: Array2<A>,
21 pub constraints: Option<Array2<A>>,
22 preconditioner: Option<Array2<A>>,
23 precision: f32,
24 maxiter: usize,
25}
26
27impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> TruncatedEig<A> {
28 pub fn new(problem: Array2<A>, order: Order) -> TruncatedEig<A> {
29 TruncatedEig {
30 precision: 1e-5,
31 maxiter: problem.len_of(Axis(0)) * 2,
32 preconditioner: None,
33 constraints: None,
34 order,
35 problem,
36 }
37 }
38
39 pub fn precision(mut self, precision: f32) -> Self {
40 self.precision = precision;
41
42 self
43 }
44
45 pub fn maxiter(mut self, maxiter: usize) -> Self {
46 self.maxiter = maxiter;
47
48 self
49 }
50
51 pub fn orthogonal_to(mut self, constraints: Array2<A>) -> Self {
52 self.constraints = Some(constraints);
53
54 self
55 }
56
57 pub fn precondition_with(mut self, preconditioner: Array2<A>) -> Self {
58 self.preconditioner = Some(preconditioner);
59
60 self
61 }
62
63 pub fn decompose(&self, num: usize) -> LobpcgResult<A> {
65 let x: Array2<f64> = generate::random((self.problem.len_of(Axis(0)), num));
66 let x = x.mapv(|x| NumCast::from(x).unwrap());
67
68 if let Some(ref preconditioner) = self.preconditioner {
69 lobpcg(
70 |y| self.problem.dot(&y),
71 x,
72 |mut y| y.assign(&preconditioner.dot(&y)),
73 self.constraints.clone(),
74 self.precision,
75 self.maxiter,
76 self.order.clone(),
77 )
78 } else {
79 lobpcg(
80 |y| self.problem.dot(&y),
81 x,
82 |_| {},
83 self.constraints.clone(),
84 self.precision,
85 self.maxiter,
86 self.order.clone(),
87 )
88 }
89 }
90}
91
92impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> IntoIterator
93 for TruncatedEig<A>
94{
95 type Item = (Array1<A>, Array2<A>);
96 type IntoIter = TruncatedEigIterator<A>;
97
98 fn into_iter(self) -> TruncatedEigIterator<A> {
99 TruncatedEigIterator {
100 step_size: 1,
101 remaining: self.problem.len_of(Axis(0)),
102 eig: self,
103 }
104 }
105}
106
107pub struct TruncatedEigIterator<A: Scalar> {
112 step_size: usize,
113 remaining: usize,
114 eig: TruncatedEig<A>,
115}
116
117impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Iterator
118 for TruncatedEigIterator<A>
119{
120 type Item = (Array1<A>, Array2<A>);
121
122 fn next(&mut self) -> Option<Self::Item> {
123 if self.remaining == 0 {
124 return None;
125 }
126
127 let step_size = usize::min(self.step_size, self.remaining);
128 let res = self.eig.decompose(step_size);
129
130 match res {
131 LobpcgResult::Ok(vals, vecs, norms) | LobpcgResult::Err(vals, vecs, norms, _) => {
132 for r_norm in norms {
134 if r_norm > NumCast::from(0.1).unwrap() {
135 return None;
136 }
137 }
138
139 let new_constraints = if let Some(ref constraints) = self.eig.constraints {
141 let eigvecs_arr: Vec<_> = constraints
142 .columns()
143 .into_iter()
144 .chain(vecs.columns().into_iter())
145 .collect();
146
147 stack(Axis(1), &eigvecs_arr).unwrap()
148 } else {
149 vecs.clone()
150 };
151
152 self.eig.constraints = Some(new_constraints);
153 self.remaining -= step_size;
154
155 Some((vals, vecs))
156 }
157 LobpcgResult::NoResult(_) => None,
158 }
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::Order;
    use super::TruncatedEig;
    use ndarray::{arr1, Array2};

    /// The three largest eigenvalues of diag(1, ..., 20) are 20, 19 and 18.
    /// They must come back in that order, one per iterator step.
    #[test]
    fn test_truncated_eig() {
        let diag = arr1(&[
            1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
            20.,
        ]);
        let a = Array2::from_diag(&diag);

        let teig = TruncatedEig::new(a, Order::Largest)
            .precision(1e-5)
            .maxiter(500);

        let res = teig
            .into_iter()
            .take(3)
            .flat_map(|x| x.0.to_vec())
            .collect::<Vec<_>>();
        let ground_truth = vec![20., 19., 18.];

        // `zip` below truncates silently. Without this check, an iterator
        // that stopped early (or yielded nothing) would pass with error 0.
        assert_eq!(
            res.len(),
            ground_truth.len(),
            "expected {} eigenvalues, got {:?}",
            ground_truth.len(),
            res
        );

        let squared_error: f64 = ground_truth
            .iter()
            .zip(res.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum();
        assert!(
            squared_error < 0.01,
            "eigenvalues {:?} deviate from {:?} (squared error {})",
            res,
            ground_truth,
            squared_error
        );
    }
}