1use super::lobpcg::{lobpcg, LobpcgResult, Order};
5use crate::error::Result;
6use crate::generate;
7use cauchy::Scalar;
8use lax::Lapack;
9use ndarray::prelude::*;
10use ndarray::ScalarOperand;
11use num_traits::{Float, NumCast};
12use std::ops::DivAssign;
13
/// Raw output of a truncated SVD computed with LOBPCG.
///
/// Only the eigenpairs of the normal matrix are stored; singular values and vectors are
/// computed on demand by `values` and `values_vectors`.
#[derive(Debug)]
pub struct TruncatedSvdResult<A> {
    /// Eigenvalues of the normal matrix (`AᵀA` or `AAᵀ`) reported by LOBPCG; their square
    /// roots are the singular values.
    eigvals: Array1<A>,
    /// Eigenvectors reported by LOBPCG, one per column.
    eigvecs: Array2<A>,
    /// The matrix that was decomposed.
    problem: Array2<A>,
    /// `true` when the problem had more rows than columns. The eigenvectors are then right
    /// singular vectors; otherwise they are left singular vectors.
    ngm: bool,
}
25
26impl<A: Float + PartialOrd + DivAssign<A> + 'static + MagnitudeCorrection> TruncatedSvdResult<A> {
27 fn singular_values_with_indices(&self) -> (Array1<A>, Vec<usize>) {
29 let mut a = self.eigvals.iter().enumerate().collect::<Vec<_>>();
31
32 a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap().reverse());
34
35 let cutoff = A::epsilon() * A::correction() * *a[0].1; let (values, indices): (Vec<A>, Vec<usize>) = a
42 .into_iter()
43 .filter(|(_, x)| *x > &cutoff)
44 .map(|(a, b)| (b.sqrt(), a))
45 .unzip();
46
47 (Array1::from(values), indices)
48 }
49
50 pub fn values(&self) -> Array1<A> {
52 let (values, _) = self.singular_values_with_indices();
53
54 values
55 }
56
57 pub fn values_vectors(&self) -> (Array2<A>, Array1<A>, Array2<A>) {
59 let (values, indices) = self.singular_values_with_indices();
60
61 let (u, v) = if self.ngm {
63 let vlarge = self.eigvecs.select(Axis(1), &indices);
64 let mut ularge = self.problem.dot(&vlarge);
65
66 ularge
67 .columns_mut()
68 .into_iter()
69 .zip(values.iter())
70 .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b));
71
72 (ularge, vlarge)
73 } else {
74 let ularge = self.eigvecs.select(Axis(1), &indices);
75
76 let mut vlarge = self.problem.t().dot(&ularge);
77 vlarge
78 .columns_mut()
79 .into_iter()
80 .zip(values.iter())
81 .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b));
82
83 (ularge, vlarge)
84 };
85
86 (u, values, v.reversed_axes())
87 }
88}
89
90pub struct TruncatedSvd<A: Scalar> {
95 order: Order,
96 problem: Array2<A>,
97 precision: f32,
98 maxiter: usize,
99}
100
101impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> TruncatedSvd<A> {
102 pub fn new(problem: Array2<A>, order: Order) -> TruncatedSvd<A> {
103 TruncatedSvd {
104 precision: 1e-5,
105 maxiter: problem.len_of(Axis(0)) * 2,
106 order,
107 problem,
108 }
109 }
110
111 pub fn precision(mut self, precision: f32) -> Self {
112 self.precision = precision;
113
114 self
115 }
116
117 pub fn maxiter(mut self, maxiter: usize) -> Self {
118 self.maxiter = maxiter;
119
120 self
121 }
122
123 pub fn decompose(self, num: usize) -> Result<TruncatedSvdResult<A>> {
125 if num < 1 {
126 panic!("The number of singular values to compute should be larger than zero!");
127 }
128
129 let (n, m) = (self.problem.nrows(), self.problem.ncols());
130
131 let x: Array2<f32> = generate::random((usize::min(n, m), num));
133 let x = x.mapv(|x| NumCast::from(x).unwrap());
134
135 let precision = self.precision * self.precision;
137
138 let res = if n > m {
140 lobpcg(
141 |y| self.problem.t().dot(&self.problem.dot(&y)),
142 x,
143 |_| {},
144 None,
145 precision,
146 self.maxiter,
147 self.order.clone(),
148 )
149 } else {
150 lobpcg(
151 |y| self.problem.dot(&self.problem.t().dot(&y)),
152 x,
153 |_| {},
154 None,
155 precision,
156 self.maxiter,
157 self.order.clone(),
158 )
159 };
160
161 match res {
163 LobpcgResult::Ok(vals, vecs, _) | LobpcgResult::Err(vals, vecs, _, _) => {
164 Ok(TruncatedSvdResult {
165 problem: self.problem,
166 eigvals: vals,
167 eigvecs: vecs,
168 ngm: n > m,
169 })
170 }
171 LobpcgResult::NoResult(err) => Err(err),
172 }
173 }
174}
175
/// Per-precision factor that scales machine epsilon when deciding which eigenvalues are
/// numerically zero.
///
/// The truncated SVD drops every eigenvalue not strictly greater than
/// `epsilon * correction() * largest_eigenvalue`.
pub trait MagnitudeCorrection {
    /// Returns the factor applied to machine epsilon.
    fn correction() -> Self;
}

impl MagnitudeCorrection for f32 {
    fn correction() -> Self {
        1_000.0
    }
}

impl MagnitudeCorrection for f64 {
    fn correction() -> Self {
        1_000_000.0
    }
}
191
#[cfg(test)]
mod tests {
    use super::Order;
    use super::TruncatedSvd;
    use crate::{close_l2, generate};

    use ndarray::{arr1, arr2, Array2};

    /// The 2x3 matrix below has singular values 5 and 3.
    #[test]
    fn test_truncated_svd() {
        let matrix = arr2(&[[3., 2., 2.], [2., 3., -2.]]);
        let expected = arr1(&[5.0, 3.0]);

        let result = TruncatedSvd::new(matrix, Order::Largest)
            .precision(1e-5)
            .maxiter(10)
            .decompose(2)
            .unwrap();
        let (_, singular_values, _) = result.values_vectors();

        close_l2(&singular_values, &expected, 1e-5);
    }

    /// Rebuilding a seeded random 50x10 matrix from all 10 singular triplets should
    /// reproduce it.
    #[test]
    fn test_truncated_svd_random() {
        let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
        let matrix: Array2<f64> = generate::random_using((50, 10), &mut rng);

        let result = TruncatedSvd::new(matrix.clone(), Order::Largest)
            .precision(1e-5)
            .maxiter(10)
            .decompose(10)
            .unwrap();
        let (u, sigma, v_t) = result.values_vectors();

        let sigma_diag = Array2::from_diag(&sigma);
        let rebuilt = u.dot(&sigma_diag.dot(&v_t));

        close_l2(&matrix, &rebuilt, 1e-5);
    }
}