lax/
svd.rs

1//! Singular-value decomposition
2//!
//! LAPACK correspondence
4//! ----------------------
5//!
6//! | f32    | f64    | c32    | c64    |
7//! |:-------|:-------|:-------|:-------|
8//! | sgesvd | dgesvd | cgesvd | zgesvd |
9//!
10
11use super::{error::*, layout::*, *};
12use cauchy::*;
13use num_traits::{ToPrimitive, Zero};
14
15pub struct SvdWork<T: Scalar> {
16    pub ju: JobSvd,
17    pub jvt: JobSvd,
18    pub layout: MatrixLayout,
19    pub s: Vec<MaybeUninit<T::Real>>,
20    pub u: Option<Vec<MaybeUninit<T>>>,
21    pub vt: Option<Vec<MaybeUninit<T>>>,
22    pub work: Vec<MaybeUninit<T>>,
23    pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
24}
25
26#[derive(Debug, Clone)]
27pub struct SvdRef<'work, T: Scalar> {
28    pub s: &'work [T::Real],
29    pub u: Option<&'work [T]>,
30    pub vt: Option<&'work [T]>,
31}
32
33#[derive(Debug, Clone)]
34pub struct SvdOwned<T: Scalar> {
35    pub s: Vec<T::Real>,
36    pub u: Option<Vec<T>>,
37    pub vt: Option<Vec<T>>,
38}
39
40pub trait SvdWorkImpl: Sized {
41    type Elem: Scalar;
42    fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self>;
43    fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>>;
44    fn eval(self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>>;
45}
46
47macro_rules! impl_svd_work_c {
48    ($s:ty, $svd:path) => {
49        impl SvdWorkImpl for SvdWork<$s> {
50            type Elem = $s;
51
52            fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self> {
53                let ju = match layout {
54                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
55                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
56                };
57                let jvt = match layout {
58                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
59                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
60                };
61
62                let m = layout.lda();
63                let mut u = match ju {
64                    JobSvd::All => Some(vec_uninit((m * m) as usize)),
65                    JobSvd::None => None,
66                    _ => unimplemented!("SVD with partial vector output is not supported yet"),
67                };
68
69                let n = layout.len();
70                let mut vt = match jvt {
71                    JobSvd::All => Some(vec_uninit((n * n) as usize)),
72                    JobSvd::None => None,
73                    _ => unimplemented!("SVD with partial vector output is not supported yet"),
74                };
75
76                let k = std::cmp::min(m, n);
77                let mut s = vec_uninit(k as usize);
78                let mut rwork = vec_uninit(5 * k as usize);
79
80                // eval work size
81                let mut info = 0;
82                let mut work_size = [Self::Elem::zero()];
83                unsafe {
84                    $svd(
85                        ju.as_ptr(),
86                        jvt.as_ptr(),
87                        &m,
88                        &n,
89                        std::ptr::null_mut(),
90                        &m,
91                        AsPtr::as_mut_ptr(&mut s),
92                        AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
93                        &m,
94                        AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
95                        &n,
96                        AsPtr::as_mut_ptr(&mut work_size),
97                        &(-1),
98                        AsPtr::as_mut_ptr(&mut rwork),
99                        &mut info,
100                    );
101                }
102                info.as_lapack_result()?;
103                let lwork = work_size[0].to_usize().unwrap();
104                let work = vec_uninit(lwork);
105                Ok(SvdWork {
106                    layout,
107                    ju,
108                    jvt,
109                    s,
110                    u,
111                    vt,
112                    work,
113                    rwork: Some(rwork),
114                })
115            }
116
117            fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
118                let m = self.layout.lda();
119                let n = self.layout.len();
120                let lwork = self.work.len().to_i32().unwrap();
121
122                let mut info = 0;
123                unsafe {
124                    $svd(
125                        self.ju.as_ptr(),
126                        self.jvt.as_ptr(),
127                        &m,
128                        &n,
129                        AsPtr::as_mut_ptr(a),
130                        &m,
131                        AsPtr::as_mut_ptr(&mut self.s),
132                        AsPtr::as_mut_ptr(
133                            self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
134                        ),
135                        &m,
136                        AsPtr::as_mut_ptr(
137                            self.vt
138                                .as_mut()
139                                .map(|x| x.as_mut_slice())
140                                .unwrap_or(&mut []),
141                        ),
142                        &n,
143                        AsPtr::as_mut_ptr(&mut self.work),
144                        &(lwork as i32),
145                        AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
146                        &mut info,
147                    );
148                }
149                info.as_lapack_result()?;
150
151                let s = unsafe { self.s.slice_assume_init_ref() };
152                let u = self
153                    .u
154                    .as_ref()
155                    .map(|v| unsafe { v.slice_assume_init_ref() });
156                let vt = self
157                    .vt
158                    .as_ref()
159                    .map(|v| unsafe { v.slice_assume_init_ref() });
160
161                match self.layout {
162                    MatrixLayout::F { .. } => Ok(SvdRef { s, u, vt }),
163                    MatrixLayout::C { .. } => Ok(SvdRef { s, u: vt, vt: u }),
164                }
165            }
166
167            fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
168                let _ref = self.calc(a)?;
169                let s = unsafe { self.s.assume_init() };
170                let u = self.u.map(|v| unsafe { v.assume_init() });
171                let vt = self.vt.map(|v| unsafe { v.assume_init() });
172                match self.layout {
173                    MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }),
174                    MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }),
175                }
176            }
177        }
178    };
179}
180impl_svd_work_c!(c64, lapack_sys::zgesvd_);
181impl_svd_work_c!(c32, lapack_sys::cgesvd_);
182
183macro_rules! impl_svd_work_r {
184    ($s:ty, $svd:path) => {
185        impl SvdWorkImpl for SvdWork<$s> {
186            type Elem = $s;
187
188            fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self> {
189                let ju = match layout {
190                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
191                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
192                };
193                let jvt = match layout {
194                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
195                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
196                };
197
198                let m = layout.lda();
199                let mut u = match ju {
200                    JobSvd::All => Some(vec_uninit((m * m) as usize)),
201                    JobSvd::None => None,
202                    _ => unimplemented!("SVD with partial vector output is not supported yet"),
203                };
204
205                let n = layout.len();
206                let mut vt = match jvt {
207                    JobSvd::All => Some(vec_uninit((n * n) as usize)),
208                    JobSvd::None => None,
209                    _ => unimplemented!("SVD with partial vector output is not supported yet"),
210                };
211
212                let k = std::cmp::min(m, n);
213                let mut s = vec_uninit(k as usize);
214
215                // eval work size
216                let mut info = 0;
217                let mut work_size = [Self::Elem::zero()];
218                unsafe {
219                    $svd(
220                        ju.as_ptr(),
221                        jvt.as_ptr(),
222                        &m,
223                        &n,
224                        std::ptr::null_mut(),
225                        &m,
226                        AsPtr::as_mut_ptr(&mut s),
227                        AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
228                        &m,
229                        AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
230                        &n,
231                        AsPtr::as_mut_ptr(&mut work_size),
232                        &(-1),
233                        &mut info,
234                    );
235                }
236                info.as_lapack_result()?;
237                let lwork = work_size[0].to_usize().unwrap();
238                let work = vec_uninit(lwork);
239                Ok(SvdWork {
240                    layout,
241                    ju,
242                    jvt,
243                    s,
244                    u,
245                    vt,
246                    work,
247                    rwork: None,
248                })
249            }
250
251            fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
252                let m = self.layout.lda();
253                let n = self.layout.len();
254                let lwork = self.work.len().to_i32().unwrap();
255
256                let mut info = 0;
257                unsafe {
258                    $svd(
259                        self.ju.as_ptr(),
260                        self.jvt.as_ptr(),
261                        &m,
262                        &n,
263                        AsPtr::as_mut_ptr(a),
264                        &m,
265                        AsPtr::as_mut_ptr(&mut self.s),
266                        AsPtr::as_mut_ptr(
267                            self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
268                        ),
269                        &m,
270                        AsPtr::as_mut_ptr(
271                            self.vt
272                                .as_mut()
273                                .map(|x| x.as_mut_slice())
274                                .unwrap_or(&mut []),
275                        ),
276                        &n,
277                        AsPtr::as_mut_ptr(&mut self.work),
278                        &(lwork as i32),
279                        &mut info,
280                    );
281                }
282                info.as_lapack_result()?;
283
284                let s = unsafe { self.s.slice_assume_init_ref() };
285                let u = self
286                    .u
287                    .as_ref()
288                    .map(|v| unsafe { v.slice_assume_init_ref() });
289                let vt = self
290                    .vt
291                    .as_ref()
292                    .map(|v| unsafe { v.slice_assume_init_ref() });
293
294                match self.layout {
295                    MatrixLayout::F { .. } => Ok(SvdRef { s, u, vt }),
296                    MatrixLayout::C { .. } => Ok(SvdRef { s, u: vt, vt: u }),
297                }
298            }
299
300            fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
301                let _ref = self.calc(a)?;
302                let s = unsafe { self.s.assume_init() };
303                let u = self.u.map(|v| unsafe { v.assume_init() });
304                let vt = self.vt.map(|v| unsafe { v.assume_init() });
305                match self.layout {
306                    MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }),
307                    MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }),
308                }
309            }
310        }
311    };
312}
313impl_svd_work_r!(f64, lapack_sys::dgesvd_);
314impl_svd_work_r!(f32, lapack_sys::sgesvd_);