1use crate::error::{LinalgError, Result};
6use crate::{cholesky::*, close_l2, eigh::*, norm::*, triangular::*};
7use cauchy::Scalar;
8use lax::Lapack;
9use ndarray::prelude::*;
10use ndarray::{Data, OwnedRepr, ScalarOperand};
11use num_traits::{Float, NumCast};
12
/// Selects which end of the spectrum the eigensolver should return.
///
/// The enum carries no data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Order {
    /// Return the eigenvalues with the largest values first.
    Largest,
    /// Return the eigenvalues with the smallest values first.
    Smallest,
}
19
/// Outcome of a [`lobpcg`] run.
///
/// Whenever at least one iterate was evaluated, the variants carry the best
/// iterate seen so far as `(eigenvalues, eigenvectors, residual norms)`,
/// where "best" means the smallest sum of residual norms.
#[derive(Debug)]
pub enum LobpcgResult<A> {
    /// The iteration loop ended without an error (all residuals fell below the
    /// tolerance or the iteration budget was used up).
    Ok(Array1<A>, Array2<A>, Vec<A>),
    /// An error interrupted the iteration loop; the best iterate found before
    /// the failure is returned together with the error.
    Err(Array1<A>, Array2<A>, Vec<A>, LinalgError),
    /// The solver failed before the first iterate was available.
    NoResult(LinalgError),
}
33
34fn sorted_eig<S: Data<Elem = A>, A: Scalar + Lapack>(
36 a: ArrayBase<S, Ix2>,
37 b: Option<ArrayBase<S, Ix2>>,
38 size: usize,
39 order: &Order,
40) -> Result<(Array1<A>, Array2<A>)> {
41 let n = a.len_of(Axis(0));
42
43 let (vals, vecs) = match b {
44 Some(b) => (a, b).eigh(UPLO::Upper).map(|x| (x.0, (x.1).0))?,
45 _ => a.eigh(UPLO::Upper)?,
46 };
47
48 Ok(match order {
49 Order::Largest => (
50 vals.slice_move(s![n-size..; -1]).mapv(Scalar::from_real),
51 vecs.slice_move(s![.., n-size..; -1]),
52 ),
53 Order::Smallest => (
54 vals.slice_move(s![..size]).mapv(Scalar::from_real),
55 vecs.slice_move(s![.., ..size]),
56 ),
57 })
58}
59
60fn ndarray_mask<A: Scalar>(matrix: ArrayView2<A>, mask: &[bool]) -> Array2<A> {
62 assert_eq!(mask.len(), matrix.ncols());
63
64 let indices = (0..mask.len())
65 .zip(mask.iter())
66 .filter(|(_, b)| **b)
67 .map(|(a, _)| a)
68 .collect::<Vec<usize>>();
69
70 matrix.select(Axis(1), &indices)
71}
72
73fn apply_constraints<A: Scalar + Lapack>(
77 mut v: ArrayViewMut<A, Ix2>,
78 cholesky_yy: &CholeskyFactorized<OwnedRepr<A>>,
79 y: ArrayView2<A>,
80) {
81 let gram_yv = y.t().dot(&v);
82
83 let u = gram_yv
84 .columns()
85 .into_iter()
86 .flat_map(|x| {
87 let res = cholesky_yy.solvec(&x).unwrap();
88
89 res.to_vec()
90 })
91 .collect::<Vec<A>>();
92
93 let rows = gram_yv.len_of(Axis(0));
94 let u = Array2::from_shape_vec((rows, u.len() / rows), u).unwrap();
95
96 v -= &(y.dot(&u));
97}
98
99fn orthonormalize<T: Scalar + Lapack>(v: Array2<T>) -> Result<(Array2<T>, Array2<T>)> {
103 let gram_vv = v.t().dot(&v);
104 let gram_vv_fac = gram_vv.cholesky(UPLO::Lower)?;
105
106 close_l2(
107 &gram_vv,
108 &gram_vv_fac.dot(&gram_vv_fac.t()),
109 NumCast::from(1e-5).unwrap(),
110 );
111
112 let v_t = v.reversed_axes();
113 let u = gram_vv_fac
114 .solve_triangular(UPLO::Lower, Diag::NonUnit, &v_t)?
115 .reversed_axes();
116
117 Ok((u, gram_vv_fac))
118}
119
120pub fn lobpcg<
141 A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
142 F: Fn(ArrayView2<A>) -> Array2<A>,
143 G: Fn(ArrayViewMut2<A>),
144>(
145 a: F,
146 mut x: Array2<A>,
147 m: G,
148 y: Option<Array2<A>>,
149 tol: f32,
150 maxiter: usize,
151 order: Order,
152) -> LobpcgResult<A> {
153 let (n, size_x) = (x.nrows(), x.ncols());
156 assert!(size_x <= n);
157
158 let mut iter = usize::min(n * 10, maxiter);
169 let tol = NumCast::from(tol).unwrap();
170
171 let cholesky_yy = y.as_ref().map(|y| {
173 let cholesky_yy = y.t().dot(y).factorizec(UPLO::Lower).unwrap();
174 apply_constraints(x.view_mut(), &cholesky_yy, y.view());
175 cholesky_yy
176 });
177
178 let (x, _) = match orthonormalize(x) {
180 Ok(x) => x,
181 Err(err) => return LobpcgResult::NoResult(err),
182 };
183
184 let ax = a(x.view());
186 let xax = x.t().dot(&ax);
187
188 let (mut lambda, eig_block) = match sorted_eig(xax.view(), None, size_x, &order) {
190 Ok(x) => x,
191 Err(err) => return LobpcgResult::NoResult(err),
192 };
193
194 let mut x = x.dot(&eig_block);
196 let mut ax = ax.dot(&eig_block);
197
198 let mut activemask = vec![true; size_x];
200
201 let mut residual_norms_history = Vec::new();
203 let mut best_result = None;
204
205 let mut previous_block_size = size_x;
206
207 let mut ident: Array2<A> = Array2::eye(size_x);
208 let ident0: Array2<A> = Array2::eye(size_x);
209 let two: A = NumCast::from(2.0).unwrap();
210
211 let mut previous_p_ap: Option<(Array2<A>, Array2<A>)> = None;
212 let mut explicit_gram_flag = true;
213
214 let final_norm = loop {
215 let lambda_diag = Array2::from_diag(&lambda);
217 let lambda_x = x.dot(&lambda_diag);
218
219 let r = &ax - &lambda_x;
221
222 let residual_norms = r
224 .columns()
225 .into_iter()
226 .map(|x| x.norm())
227 .collect::<Vec<A::Real>>();
228 residual_norms_history.push(residual_norms.clone());
229
230 let sum_rnorm: A::Real = residual_norms.iter().cloned().sum();
232 if best_result
233 .as_ref()
234 .map(|x: &(_, _, Vec<A::Real>)| x.2.iter().cloned().sum::<A::Real>() > sum_rnorm)
235 .unwrap_or(true)
236 {
237 best_result = Some((lambda.clone(), x.clone(), residual_norms.clone()));
238 }
239
240 activemask = residual_norms
242 .iter()
243 .zip(activemask.iter())
244 .map(|(x, a)| *x > tol && *a)
245 .collect();
246
247 let current_block_size = activemask.iter().filter(|x| **x).count();
249 if current_block_size != previous_block_size {
250 previous_block_size = current_block_size;
251 ident = Array2::eye(current_block_size);
252 }
253
254 if current_block_size == 0 || iter == 0 {
257 break Ok(residual_norms);
258 }
259
260 let mut active_block_r = ndarray_mask(r.view(), &activemask);
262 m(active_block_r.view_mut());
264 if let (Some(ref y), Some(ref cholesky_yy)) = (&y, &cholesky_yy) {
266 apply_constraints(active_block_r.view_mut(), cholesky_yy, y.view());
267 }
268 active_block_r -= &x.dot(&x.t().dot(&active_block_r));
270
271 let (r, _) = match orthonormalize(active_block_r) {
272 Ok(x) => x,
273 Err(err) => break Err(err),
274 };
275
276 let ar = a(r.view());
277
278 let max_rnorm_float = if A::epsilon() > NumCast::from(1e-8).unwrap() {
280 NumCast::from(1.0).unwrap()
281 } else {
282 NumCast::from(1.0e-8).unwrap()
283 };
284
285 let max_norm = residual_norms
287 .into_iter()
288 .fold(A::Real::neg_infinity(), A::Real::max);
289 explicit_gram_flag = max_norm <= max_rnorm_float || explicit_gram_flag;
290
291 let xar = x.t().dot(&ar);
293 let mut rar = r.t().dot(&ar);
294
295 let (xax, xx, rr, xr) = if explicit_gram_flag {
299 rar = (&rar + &rar.t()) / two;
300 let xax = x.t().dot(&ax);
301
302 (
303 (&xax + &xax.t()) / two,
304 x.t().dot(&x),
305 r.t().dot(&r),
306 x.t().dot(&r),
307 )
308 } else {
309 (
310 lambda_diag,
311 ident0.clone(),
312 ident.clone(),
313 Array2::zeros((size_x, current_block_size)),
314 )
315 };
316
317 let mut p_ap = previous_p_ap
319 .as_ref()
320 .and_then(|(p, ap)| {
321 let active_p = ndarray_mask(p.view(), &activemask);
322 let active_ap = ndarray_mask(ap.view(), &activemask);
323
324 orthonormalize(active_p).map(|x| (active_ap, x)).ok()
325 })
326 .and_then(|(active_ap, (active_p, p_r))| {
327 let active_ap = active_ap.reversed_axes();
329 p_r.solve_triangular(UPLO::Lower, Diag::NonUnit, &active_ap)
330 .map(|active_ap| (active_p, active_ap.reversed_axes()))
331 .ok()
332 });
333
334 let result = p_ap
339 .as_ref()
340 .ok_or(LinalgError::Lapack(
341 lax::error::Error::LapackComputationalFailure { return_code: 1 },
342 ))
343 .and_then(|(active_p, active_ap)| {
344 let xap = x.t().dot(active_ap);
345 let rap = r.t().dot(active_ap);
346 let pap = active_p.t().dot(active_ap);
347 let xp = x.t().dot(active_p);
348 let rp = r.t().dot(active_p);
349 let (pap, pp) = if explicit_gram_flag {
350 ((&pap + &pap.t()) / two, active_p.t().dot(active_p))
351 } else {
352 (pap, ident.clone())
353 };
354
355 sorted_eig(
356 concatenate![
357 Axis(0),
358 concatenate![Axis(1), xax, xar, xap],
359 concatenate![Axis(1), xar.t(), rar, rap],
360 concatenate![Axis(1), xap.t(), rap.t(), pap]
361 ],
362 Some(concatenate![
363 Axis(0),
364 concatenate![Axis(1), xx, xr, xp],
365 concatenate![Axis(1), xr.t(), rr, rp],
366 concatenate![Axis(1), xp.t(), rp.t(), pp]
367 ]),
368 size_x,
369 &order,
370 )
371 })
372 .or_else(|_| {
373 p_ap = None;
374
375 sorted_eig(
376 concatenate![
377 Axis(0),
378 concatenate![Axis(1), xax, xar],
379 concatenate![Axis(1), xar.t(), rar]
380 ],
381 Some(concatenate![
382 Axis(0),
383 concatenate![Axis(1), xx, xr],
384 concatenate![Axis(1), xr.t(), rr]
385 ]),
386 size_x,
387 &order,
388 )
389 });
390
391 let eig_vecs;
393 match result {
394 Ok((x, y)) => {
395 lambda = x;
396 eig_vecs = y;
397 }
398 Err(x) => break Err(x),
399 }
400
401 let (p, ap, tau) = if let Some((active_p, active_ap)) = p_ap {
403 let tau = eig_vecs.slice(s![..size_x, ..]);
405 let alpha = eig_vecs.slice(s![size_x..size_x + current_block_size, ..]);
407 let gamma = eig_vecs.slice(s![size_x + current_block_size.., ..]);
409
410 let updated_p = r.dot(&alpha) + active_p.dot(&gamma);
412 let updated_ap = ar.dot(&alpha) + active_ap.dot(&gamma);
413
414 (updated_p, updated_ap, tau)
415 } else {
416 let tau = eig_vecs.slice(s![..size_x, ..]);
418 let alpha = eig_vecs.slice(s![size_x.., ..]);
420
421 let updated_p = r.dot(&alpha);
423 let updated_ap = ar.dot(&alpha);
424
425 (updated_p, updated_ap, tau)
426 };
427
428 x = x.dot(&tau) + &p;
430 ax = ax.dot(&tau) + ≈
431
432 previous_p_ap = Some((p, ap));
433
434 iter -= 1;
435 };
436
437 let (vals, vecs, rnorm) = best_result.unwrap();
439 let rnorm = rnorm.into_iter().map(Scalar::from_real).collect();
440
441 match final_norm {
442 Ok(_) => LobpcgResult::Ok(vals, vecs, rnorm),
443 Err(err) => LobpcgResult::Err(vals, vecs, rnorm, err),
444 }
445}
446
447#[cfg(test)]
448mod tests {
449 use super::lobpcg;
450 use super::ndarray_mask;
451 use super::orthonormalize;
452 use super::sorted_eig;
453 use super::LobpcgResult;
454 use super::Order;
455 use crate::close_l2;
456 use crate::generate;
457 use crate::qr::*;
458 use ndarray::prelude::*;
459
460 #[test]
462 fn test_sorted_eigen() {
463 let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
464 let matrix: Array2<f64> = generate::random_using((10, 10), &mut rng) * 10.0;
465 let matrix = matrix.t().dot(&matrix);
466
467 let (vals, vecs) = sorted_eig(matrix.view(), None, 10, &Order::Largest).unwrap();
469
470 let diag = Array2::from_diag(&vals);
472 let rec = (vecs.dot(&diag)).dot(&vecs.t());
473
474 close_l2(&matrix, &rec, 1e-5);
475 }
476
477 #[test]
479 fn test_masking() {
480 let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
481 let matrix: Array2<f64> = generate::random_using((10, 5), &mut rng) * 10.0;
482 let masked_matrix = ndarray_mask(matrix.view(), &[true, true, false, true, false]);
483 close_l2(
484 &masked_matrix.slice(s![.., 2]),
485 &matrix.slice(s![.., 3]),
486 1e-12,
487 );
488 }
489
490 #[test]
492 fn test_orthonormalize() {
493 let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
494 let matrix: Array2<f64> = generate::random_using((10, 10), &mut rng) * 10.0;
495
496 let (n, l) = orthonormalize(matrix.clone()).unwrap();
497
498 let identity = n.dot(&n.t());
500 close_l2(&identity, &Array2::eye(10), 1e-2);
501
502 let (_, r) = matrix.qr().unwrap();
504 close_l2(&r.mapv(|x| x.abs()), &l.t().mapv(|x| x.abs()), 1e-2);
505 }
506
507 fn assert_symmetric(a: &Array2<f64>) {
508 close_l2(a, &a.t(), 1e-5);
509 }
510
511 fn check_eigenvalues(a: &Array2<f64>, order: Order, num: usize, ground_truth_eigvals: &[f64]) {
512 assert_symmetric(a);
513
514 let n = a.len_of(Axis(0));
515 let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
516 let x: Array2<f64> = generate::random_using((n, num), &mut rng);
517
518 let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-5, n * 2, order);
519 match result {
520 LobpcgResult::Ok(vals, _, r_norms) | LobpcgResult::Err(vals, _, r_norms, _) => {
521 for (i, norm) in r_norms.into_iter().enumerate() {
523 if norm > 1e-5 {
524 println!("==== Assertion Failed ====");
525 println!("The {}th eigenvalue estimation did not converge!", i);
526 panic!("Too large deviation of residual norm: {} > 0.01", norm);
527 }
528 }
529
530 if ground_truth_eigvals.len() == num {
532 close_l2(
533 &Array1::from(ground_truth_eigvals.to_vec()),
534 &vals,
535 num as f64 * 5e-4,
536 )
537 }
538 }
539 LobpcgResult::NoResult(err) => panic!("Did not converge: {:?}", err),
540 }
541 }
542
543 #[test]
545 fn test_eigsolver_diag() {
546 let diag = arr1(&[
547 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
548 20.,
549 ]);
550 let a = Array2::from_diag(&diag);
551
552 check_eigenvalues(&a, Order::Largest, 3, &[20., 19., 18.]);
553 check_eigenvalues(&a, Order::Smallest, 3, &[1., 2., 3.]);
554 }
555
556 #[test]
558 fn test_eigsolver_constructed() {
559 let n = 50;
560 let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
561 let tmp = generate::random_using((n, n), &mut rng);
562 let (v, _) = orthonormalize(tmp).unwrap();
564
565 let t = Array2::from_diag(&Array1::linspace(n as f64, -(n as f64), n));
567 let a = v.dot(&t.dot(&v.t()));
568
569 check_eigenvalues(&a, Order::Largest, 5, &[50.0, 48.0, 46.0, 44.0, 42.0]);
571 check_eigenvalues(&a, Order::Smallest, 5, &[-50.0, -48.0, -46.0, -44.0, -42.0]);
572 }
573
574 #[test]
575 fn test_eigsolver_constrained() {
576 let diag = arr1(&[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
577 let a = Array2::from_diag(&diag);
578 let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
579 let x: Array2<f64> = generate::random_using((10, 1), &mut rng);
580 let y: Array2<f64> = arr2(&[
581 [1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
582 [0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0.],
583 ])
584 .reversed_axes();
585
586 let result = lobpcg(
587 |y| a.dot(&y),
588 x,
589 |_| {},
590 Some(y),
591 1e-10,
592 50,
593 Order::Smallest,
594 );
595 match result {
596 LobpcgResult::Ok(vals, vecs, r_norms) | LobpcgResult::Err(vals, vecs, r_norms, _) => {
597 for (i, norm) in r_norms.into_iter().enumerate() {
599 if norm > 0.01 {
600 println!("==== Assertion Failed ====");
601 println!("The {}th eigenvalue estimation did not converge!", i);
602 panic!("Too large deviation of residual norm: {} > 0.01", norm);
603 }
604 }
605
606 close_l2(&vals, &Array1::from(vec![3.0]), 1e-10);
608 close_l2(
609 &vecs.column(0).mapv(|x| x.abs()),
610 &arr1(&[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
611 1e-5,
612 );
613 }
614 LobpcgResult::NoResult(err) => panic!("Did not converge: {:?}", err),
615 }
616 }
617}