lax/
svddc.rs

1//! Compute singular value decomposition with divide-and-conquer algorithm
2//!
3//! LAPACK correspondence
4//! ----------------------
5//!
6//! | f32    | f64    | c32    | c64    |
7//! |:-------|:-------|:-------|:-------|
8//! | sgesdd | dgesdd | cgesdd | zgesdd |
9//!
10
11use crate::{error::*, layout::MatrixLayout, *};
12use cauchy::*;
13use num_traits::{ToPrimitive, Zero};
14
15pub struct SvdDcWork<T: Scalar> {
16    pub jobz: JobSvd,
17    pub layout: MatrixLayout,
18    pub s: Vec<MaybeUninit<T::Real>>,
19    pub u: Option<Vec<MaybeUninit<T>>>,
20    pub vt: Option<Vec<MaybeUninit<T>>>,
21    pub work: Vec<MaybeUninit<T>>,
22    pub iwork: Vec<MaybeUninit<i32>>,
23    pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
24}
25
26pub trait SvdDcWorkImpl: Sized {
27    type Elem: Scalar;
28    fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self>;
29    fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>>;
30    fn eval(self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>>;
31}
32
33macro_rules! impl_svd_dc_work_c {
34    ($s:ty, $sdd:path) => {
35        impl SvdDcWorkImpl for SvdDcWork<$s> {
36            type Elem = $s;
37
38            fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self> {
39                let m = layout.lda();
40                let n = layout.len();
41                let k = m.min(n);
42                let (u_col, vt_row) = match jobz {
43                    JobSvd::All | JobSvd::None => (m, n),
44                    JobSvd::Some => (k, k),
45                };
46
47                let mut s = vec_uninit(k as usize);
48                let (mut u, mut vt) = match jobz {
49                    JobSvd::All => (
50                        Some(vec_uninit((m * m) as usize)),
51                        Some(vec_uninit((n * n) as usize)),
52                    ),
53                    JobSvd::Some => (
54                        Some(vec_uninit((m * u_col) as usize)),
55                        Some(vec_uninit((n * vt_row) as usize)),
56                    ),
57                    JobSvd::None => (None, None),
58                };
59                let mut iwork = vec_uninit(8 * k as usize);
60
61                let mx = n.max(m) as usize;
62                let mn = n.min(m) as usize;
63                let lrwork = match jobz {
64                    JobSvd::None => 7 * mn,
65                    _ => std::cmp::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn),
66                };
67                let mut rwork = vec_uninit(lrwork);
68
69                let mut info = 0;
70                let mut work_size = [Self::Elem::zero()];
71                unsafe {
72                    $sdd(
73                        jobz.as_ptr(),
74                        &m,
75                        &n,
76                        std::ptr::null_mut(),
77                        &m,
78                        AsPtr::as_mut_ptr(&mut s),
79                        AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
80                        &m,
81                        AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
82                        &vt_row,
83                        AsPtr::as_mut_ptr(&mut work_size),
84                        &(-1),
85                        AsPtr::as_mut_ptr(&mut rwork),
86                        AsPtr::as_mut_ptr(&mut iwork),
87                        &mut info,
88                    );
89                }
90                info.as_lapack_result()?;
91                let lwork = work_size[0].to_usize().unwrap();
92                let work = vec_uninit(lwork);
93                Ok(SvdDcWork {
94                    layout,
95                    jobz,
96                    iwork,
97                    work,
98                    rwork: Some(rwork),
99                    u,
100                    vt,
101                    s,
102                })
103            }
104
105            fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
106                let m = self.layout.lda();
107                let n = self.layout.len();
108                let k = m.min(n);
109                let (_, vt_row) = match self.jobz {
110                    JobSvd::All | JobSvd::None => (m, n),
111                    JobSvd::Some => (k, k),
112                };
113                let lwork = self.work.len().to_i32().unwrap();
114
115                let mut info = 0;
116                unsafe {
117                    $sdd(
118                        self.jobz.as_ptr(),
119                        &m,
120                        &n,
121                        AsPtr::as_mut_ptr(a),
122                        &m,
123                        AsPtr::as_mut_ptr(&mut self.s),
124                        AsPtr::as_mut_ptr(
125                            self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
126                        ),
127                        &m,
128                        AsPtr::as_mut_ptr(
129                            self.vt
130                                .as_mut()
131                                .map(|x| x.as_mut_slice())
132                                .unwrap_or(&mut []),
133                        ),
134                        &vt_row,
135                        AsPtr::as_mut_ptr(&mut self.work),
136                        &lwork,
137                        AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
138                        AsPtr::as_mut_ptr(&mut self.iwork),
139                        &mut info,
140                    );
141                }
142                info.as_lapack_result()?;
143
144                let s = unsafe { self.s.slice_assume_init_ref() };
145                let u = self
146                    .u
147                    .as_ref()
148                    .map(|v| unsafe { v.slice_assume_init_ref() });
149                let vt = self
150                    .vt
151                    .as_ref()
152                    .map(|v| unsafe { v.slice_assume_init_ref() });
153
154                Ok(match self.layout {
155                    MatrixLayout::F { .. } => SvdRef { s, u, vt },
156                    MatrixLayout::C { .. } => SvdRef { s, u: vt, vt: u },
157                })
158            }
159
160            fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
161                let _ref = self.calc(a)?;
162                let s = unsafe { self.s.assume_init() };
163                let u = self.u.map(|v| unsafe { v.assume_init() });
164                let vt = self.vt.map(|v| unsafe { v.assume_init() });
165                Ok(match self.layout {
166                    MatrixLayout::F { .. } => SvdOwned { s, u, vt },
167                    MatrixLayout::C { .. } => SvdOwned { s, u: vt, vt: u },
168                })
169            }
170        }
171    };
172}
173impl_svd_dc_work_c!(c64, lapack_sys::zgesdd_);
174impl_svd_dc_work_c!(c32, lapack_sys::cgesdd_);
175
176macro_rules! impl_svd_dc_work_r {
177    ($s:ty, $sdd:path) => {
178        impl SvdDcWorkImpl for SvdDcWork<$s> {
179            type Elem = $s;
180
181            fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self> {
182                let m = layout.lda();
183                let n = layout.len();
184                let k = m.min(n);
185                let (u_col, vt_row) = match jobz {
186                    JobSvd::All | JobSvd::None => (m, n),
187                    JobSvd::Some => (k, k),
188                };
189
190                let mut s = vec_uninit(k as usize);
191                let (mut u, mut vt) = match jobz {
192                    JobSvd::All => (
193                        Some(vec_uninit((m * m) as usize)),
194                        Some(vec_uninit((n * n) as usize)),
195                    ),
196                    JobSvd::Some => (
197                        Some(vec_uninit((m * u_col) as usize)),
198                        Some(vec_uninit((n * vt_row) as usize)),
199                    ),
200                    JobSvd::None => (None, None),
201                };
202                let mut iwork = vec_uninit(8 * k as usize);
203
204                let mut info = 0;
205                let mut work_size = [Self::Elem::zero()];
206                unsafe {
207                    $sdd(
208                        jobz.as_ptr(),
209                        &m,
210                        &n,
211                        std::ptr::null_mut(),
212                        &m,
213                        AsPtr::as_mut_ptr(&mut s),
214                        AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
215                        &m,
216                        AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
217                        &vt_row,
218                        AsPtr::as_mut_ptr(&mut work_size),
219                        &(-1),
220                        AsPtr::as_mut_ptr(&mut iwork),
221                        &mut info,
222                    );
223                }
224                info.as_lapack_result()?;
225                let lwork = work_size[0].to_usize().unwrap();
226                let work = vec_uninit(lwork);
227                Ok(SvdDcWork {
228                    layout,
229                    jobz,
230                    iwork,
231                    work,
232                    rwork: None,
233                    u,
234                    vt,
235                    s,
236                })
237            }
238
239            fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
240                let m = self.layout.lda();
241                let n = self.layout.len();
242                let k = m.min(n);
243                let (_, vt_row) = match self.jobz {
244                    JobSvd::All | JobSvd::None => (m, n),
245                    JobSvd::Some => (k, k),
246                };
247                let lwork = self.work.len().to_i32().unwrap();
248
249                let mut info = 0;
250                unsafe {
251                    $sdd(
252                        self.jobz.as_ptr(),
253                        &m,
254                        &n,
255                        AsPtr::as_mut_ptr(a),
256                        &m,
257                        AsPtr::as_mut_ptr(&mut self.s),
258                        AsPtr::as_mut_ptr(
259                            self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
260                        ),
261                        &m,
262                        AsPtr::as_mut_ptr(
263                            self.vt
264                                .as_mut()
265                                .map(|x| x.as_mut_slice())
266                                .unwrap_or(&mut []),
267                        ),
268                        &vt_row,
269                        AsPtr::as_mut_ptr(&mut self.work),
270                        &lwork,
271                        AsPtr::as_mut_ptr(&mut self.iwork),
272                        &mut info,
273                    );
274                }
275                info.as_lapack_result()?;
276
277                let s = unsafe { self.s.slice_assume_init_ref() };
278                let u = self
279                    .u
280                    .as_ref()
281                    .map(|v| unsafe { v.slice_assume_init_ref() });
282                let vt = self
283                    .vt
284                    .as_ref()
285                    .map(|v| unsafe { v.slice_assume_init_ref() });
286
287                Ok(match self.layout {
288                    MatrixLayout::F { .. } => SvdRef { s, u, vt },
289                    MatrixLayout::C { .. } => SvdRef { s, u: vt, vt: u },
290                })
291            }
292
293            fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
294                let _ref = self.calc(a)?;
295                let s = unsafe { self.s.assume_init() };
296                let u = self.u.map(|v| unsafe { v.assume_init() });
297                let vt = self.vt.map(|v| unsafe { v.assume_init() });
298                Ok(match self.layout {
299                    MatrixLayout::F { .. } => SvdOwned { s, u, vt },
300                    MatrixLayout::C { .. } => SvdOwned { s, u: vt, vt: u },
301                })
302            }
303        }
304    };
305}
306impl_svd_dc_work_r!(f64, lapack_sys::dgesdd_);
307impl_svd_dc_work_r!(f32, lapack_sys::sgesdd_);