// ndarray_linalg/lobpcg/svd.rs

//! Truncated singular value decomposition
//!
//! This module computes the k largest/smallest singular values/vectors of a dense matrix.
4use super::lobpcg::{lobpcg, LobpcgResult, Order};
5use crate::error::Result;
6use crate::generate;
7use cauchy::Scalar;
8use lax::Lapack;
9use ndarray::prelude::*;
10use ndarray::ScalarOperand;
11use num_traits::{Float, NumCast};
12use std::ops::DivAssign;
13
14/// The result of a eigenvalue decomposition, not yet transformed into singular values/vectors
15///
16/// Provides methods for either calculating just the singular values with reduced cost or the
17/// vectors with additional cost of matrix multiplication.
18#[derive(Debug)]
19pub struct TruncatedSvdResult<A> {
20    eigvals: Array1<A>,
21    eigvecs: Array2<A>,
22    problem: Array2<A>,
23    ngm: bool,
24}
25
26impl<A: Float + PartialOrd + DivAssign<A> + 'static + MagnitudeCorrection> TruncatedSvdResult<A> {
27    /// Returns singular values ordered by magnitude with indices.
28    fn singular_values_with_indices(&self) -> (Array1<A>, Vec<usize>) {
29        // numerate eigenvalues
30        let mut a = self.eigvals.iter().enumerate().collect::<Vec<_>>();
31
32        // sort by magnitude
33        a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap().reverse());
34
35        // calculate cut-off magnitude (borrowed from scipy)
36        let cutoff = A::epsilon() * // float precision
37                     A::correction() * // correction term (see trait below)
38                     *a[0].1; // max eigenvalue
39
40        // filter low singular values away
41        let (values, indices): (Vec<A>, Vec<usize>) = a
42            .into_iter()
43            .filter(|(_, x)| *x > &cutoff)
44            .map(|(a, b)| (b.sqrt(), a))
45            .unzip();
46
47        (Array1::from(values), indices)
48    }
49
50    /// Returns singular values ordered by magnitude
51    pub fn values(&self) -> Array1<A> {
52        let (values, _) = self.singular_values_with_indices();
53
54        values
55    }
56
57    /// Returns singular values, left-singular vectors and right-singular vectors
58    pub fn values_vectors(&self) -> (Array2<A>, Array1<A>, Array2<A>) {
59        let (values, indices) = self.singular_values_with_indices();
60
61        // branch n > m (for A is [n x m])
62        let (u, v) = if self.ngm {
63            let vlarge = self.eigvecs.select(Axis(1), &indices);
64            let mut ularge = self.problem.dot(&vlarge);
65
66            ularge
67                .columns_mut()
68                .into_iter()
69                .zip(values.iter())
70                .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b));
71
72            (ularge, vlarge)
73        } else {
74            let ularge = self.eigvecs.select(Axis(1), &indices);
75
76            let mut vlarge = self.problem.t().dot(&ularge);
77            vlarge
78                .columns_mut()
79                .into_iter()
80                .zip(values.iter())
81                .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b));
82
83            (ularge, vlarge)
84        };
85
86        (u, values, v.reversed_axes())
87    }
88}
89
90/// Truncated singular value decomposition
91///
92/// Wraps the LOBPCG algorithm and provides convenient builder-pattern access to
93/// parameter like maximal iteration, precision and constraint matrix.
94pub struct TruncatedSvd<A: Scalar> {
95    order: Order,
96    problem: Array2<A>,
97    precision: f32,
98    maxiter: usize,
99}
100
101impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> TruncatedSvd<A> {
102    pub fn new(problem: Array2<A>, order: Order) -> TruncatedSvd<A> {
103        TruncatedSvd {
104            precision: 1e-5,
105            maxiter: problem.len_of(Axis(0)) * 2,
106            order,
107            problem,
108        }
109    }
110
111    pub fn precision(mut self, precision: f32) -> Self {
112        self.precision = precision;
113
114        self
115    }
116
117    pub fn maxiter(mut self, maxiter: usize) -> Self {
118        self.maxiter = maxiter;
119
120        self
121    }
122
123    // calculate the eigenvalue decomposition
124    pub fn decompose(self, num: usize) -> Result<TruncatedSvdResult<A>> {
125        if num < 1 {
126            panic!("The number of singular values to compute should be larger than zero!");
127        }
128
129        let (n, m) = (self.problem.nrows(), self.problem.ncols());
130
131        // generate initial matrix
132        let x: Array2<f32> = generate::random((usize::min(n, m), num));
133        let x = x.mapv(|x| NumCast::from(x).unwrap());
134
135        // square precision because the SVD squares the eigenvalue as well
136        let precision = self.precision * self.precision;
137
138        // use problem definition with less operations required
139        let res = if n > m {
140            lobpcg(
141                |y| self.problem.t().dot(&self.problem.dot(&y)),
142                x,
143                |_| {},
144                None,
145                precision,
146                self.maxiter,
147                self.order.clone(),
148            )
149        } else {
150            lobpcg(
151                |y| self.problem.dot(&self.problem.t().dot(&y)),
152                x,
153                |_| {},
154                None,
155                precision,
156                self.maxiter,
157                self.order.clone(),
158            )
159        };
160
161        // convert into TruncatedSvdResult
162        match res {
163            LobpcgResult::Ok(vals, vecs, _) | LobpcgResult::Err(vals, vecs, _, _) => {
164                Ok(TruncatedSvdResult {
165                    problem: self.problem,
166                    eigvals: vals,
167                    eigvecs: vecs,
168                    ngm: n > m,
169                })
170            }
171            LobpcgResult::NoResult(err) => Err(err),
172        }
173    }
174}
175
/// Per-type scale factor applied to the machine epsilon when deciding which eigenvalues
/// of a truncated SVD are numerically zero.
///
/// The values (`1e3` for `f32`, `1e6` for `f64`) are part of the cut-off heuristic
/// borrowed from SciPy.
pub trait MagnitudeCorrection {
    /// Returns the correction factor for this floating-point type.
    fn correction() -> Self;
}

impl MagnitudeCorrection for f32 {
    fn correction() -> Self {
        1_000.0
    }
}

impl MagnitudeCorrection for f64 {
    fn correction() -> Self {
        1_000_000.0
    }
}
191
#[cfg(test)]
mod tests {
    use super::{Order, TruncatedSvd};
    use crate::{close_l2, generate};

    use ndarray::{arr1, arr2, Array2};

    /// A 2x3 matrix whose singular values are known to be 5 and 3
    /// (`AAᵀ = [[17, 8], [8, 17]]` has eigenvalues 25 and 9).
    #[test]
    fn test_truncated_svd() {
        let matrix = arr2(&[[3., 2., 2.], [2., 3., -2.]]);

        let svd = TruncatedSvd::new(matrix, Order::Largest)
            .precision(1e-5)
            .maxiter(10)
            .decompose(2)
            .unwrap();
        let (_, singular_values, _) = svd.values_vectors();

        close_l2(&singular_values, &arr1(&[5.0, 3.0]), 1e-5);
    }

    /// Computing all 10 singular triplets of a seeded random 50x10 matrix must
    /// reconstruct it: `A ≈ U Σ Vᵀ`.
    #[test]
    fn test_truncated_svd_random() {
        let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
        let matrix: Array2<f64> = generate::random_using((50, 10), &mut rng);

        let svd = TruncatedSvd::new(matrix.clone(), Order::Largest)
            .precision(1e-5)
            .maxiter(10)
            .decompose(10)
            .unwrap();
        let (u, sigma, v_t) = svd.values_vectors();

        // multiply as U·(Σ·Vᵀ) so the floating-point rounding stays the same
        let sigma_v_t = Array2::from_diag(&sigma).dot(&v_t);
        let reconstructed = u.dot(&sigma_v_t);

        close_l2(&matrix, &reconstructed, 1e-5);
    }
}