ndarray_linalg/lobpcg/
eig.rs

1use super::lobpcg::{lobpcg, LobpcgResult, Order};
2use crate::{generate, Scalar};
3use lax::Lapack;
4
// Implements truncated eigenvalue decomposition.
7use ndarray::prelude::*;
8use ndarray::stack;
9use ndarray::ScalarOperand;
10use num_traits::{Float, NumCast};
11
12/// Truncated eigenproblem solver
13///
14/// This struct wraps the LOBPCG algorithm and provides convenient builder-pattern access to
15/// parameter like maximal iteration, precision and constraint matrix. Furthermore it allows
16/// conversion into a iterative solver where each iteration step yields a new eigenvalue/vector
17/// pair.
18pub struct TruncatedEig<A: Scalar> {
19    order: Order,
20    problem: Array2<A>,
21    pub constraints: Option<Array2<A>>,
22    preconditioner: Option<Array2<A>>,
23    precision: f32,
24    maxiter: usize,
25}
26
27impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> TruncatedEig<A> {
28    pub fn new(problem: Array2<A>, order: Order) -> TruncatedEig<A> {
29        TruncatedEig {
30            precision: 1e-5,
31            maxiter: problem.len_of(Axis(0)) * 2,
32            preconditioner: None,
33            constraints: None,
34            order,
35            problem,
36        }
37    }
38
39    pub fn precision(mut self, precision: f32) -> Self {
40        self.precision = precision;
41
42        self
43    }
44
45    pub fn maxiter(mut self, maxiter: usize) -> Self {
46        self.maxiter = maxiter;
47
48        self
49    }
50
51    pub fn orthogonal_to(mut self, constraints: Array2<A>) -> Self {
52        self.constraints = Some(constraints);
53
54        self
55    }
56
57    pub fn precondition_with(mut self, preconditioner: Array2<A>) -> Self {
58        self.preconditioner = Some(preconditioner);
59
60        self
61    }
62
63    // calculate the eigenvalues decompose
64    pub fn decompose(&self, num: usize) -> LobpcgResult<A> {
65        let x: Array2<f64> = generate::random((self.problem.len_of(Axis(0)), num));
66        let x = x.mapv(|x| NumCast::from(x).unwrap());
67
68        if let Some(ref preconditioner) = self.preconditioner {
69            lobpcg(
70                |y| self.problem.dot(&y),
71                x,
72                |mut y| y.assign(&preconditioner.dot(&y)),
73                self.constraints.clone(),
74                self.precision,
75                self.maxiter,
76                self.order.clone(),
77            )
78        } else {
79            lobpcg(
80                |y| self.problem.dot(&y),
81                x,
82                |_| {},
83                self.constraints.clone(),
84                self.precision,
85                self.maxiter,
86                self.order.clone(),
87            )
88        }
89    }
90}
91
92impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> IntoIterator
93    for TruncatedEig<A>
94{
95    type Item = (Array1<A>, Array2<A>);
96    type IntoIter = TruncatedEigIterator<A>;
97
98    fn into_iter(self) -> TruncatedEigIterator<A> {
99        TruncatedEigIterator {
100            step_size: 1,
101            remaining: self.problem.len_of(Axis(0)),
102            eig: self,
103        }
104    }
105}
106
107/// Truncate eigenproblem iterator
108///
109/// This wraps a truncated eigenproblem and provides an iterator where each step yields a new
110/// eigenvalue/vector pair. Useful for generating pairs until a certain condition is met.
111pub struct TruncatedEigIterator<A: Scalar> {
112    step_size: usize,
113    remaining: usize,
114    eig: TruncatedEig<A>,
115}
116
117impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Iterator
118    for TruncatedEigIterator<A>
119{
120    type Item = (Array1<A>, Array2<A>);
121
122    fn next(&mut self) -> Option<Self::Item> {
123        if self.remaining == 0 {
124            return None;
125        }
126
127        let step_size = usize::min(self.step_size, self.remaining);
128        let res = self.eig.decompose(step_size);
129
130        match res {
131            LobpcgResult::Ok(vals, vecs, norms) | LobpcgResult::Err(vals, vecs, norms, _) => {
132                // abort if any eigenproblem did not converge
133                for r_norm in norms {
134                    if r_norm > NumCast::from(0.1).unwrap() {
135                        return None;
136                    }
137                }
138
139                // add the new eigenvector to the internal constrain matrix
140                let new_constraints = if let Some(ref constraints) = self.eig.constraints {
141                    let eigvecs_arr: Vec<_> = constraints
142                        .columns()
143                        .into_iter()
144                        .chain(vecs.columns().into_iter())
145                        .collect();
146
147                    stack(Axis(1), &eigvecs_arr).unwrap()
148                } else {
149                    vecs.clone()
150                };
151
152                self.eig.constraints = Some(new_constraints);
153                self.remaining -= step_size;
154
155                Some((vals, vecs))
156            }
157            LobpcgResult::NoResult(_) => None,
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::Order;
    use super::TruncatedEig;
    use ndarray::{arr1, Array2};

    /// The three largest eigenvalues of `diag(1, 2, ..., 20)` are 20, 19 and 18. The iterator
    /// must produce all three, in that order.
    #[test]
    fn test_truncated_eig() {
        let diag = arr1(&[
            1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
            20.,
        ]);
        let a = Array2::from_diag(&diag);

        let teig = TruncatedEig::new(a, Order::Largest)
            .precision(1e-5)
            .maxiter(500);

        let res: Vec<f64> = teig
            .into_iter()
            .take(3)
            .flat_map(|x| x.0.to_vec())
            .collect();
        let ground_truth: Vec<f64> = vec![20., 19., 18.];

        // `zip` stops at the shorter input. Without this check, an iterator that ends early
        // (e.g. on non-convergence) would make the error sum trivially small and pass.
        assert_eq!(
            res.len(),
            ground_truth.len(),
            "expected {} eigenvalues, got {:?}",
            ground_truth.len(),
            res
        );

        let squared_error: f64 = ground_truth
            .iter()
            .zip(res.iter())
            .map(|(&x, &y)| (x - y) * (x - y))
            .sum();

        assert!(
            squared_error < 0.01,
            "eigenvalues {:?} differ from {:?}",
            res,
            ground_truth
        );
    }
}