1use crate::{error::*, layout::MatrixLayout, *};
12use cauchy::*;
13use num_traits::{ToPrimitive, Zero};
14
15pub struct SvdDcWork<T: Scalar> {
16 pub jobz: JobSvd,
17 pub layout: MatrixLayout,
18 pub s: Vec<MaybeUninit<T::Real>>,
19 pub u: Option<Vec<MaybeUninit<T>>>,
20 pub vt: Option<Vec<MaybeUninit<T>>>,
21 pub work: Vec<MaybeUninit<T>>,
22 pub iwork: Vec<MaybeUninit<i32>>,
23 pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
24}
25
26pub trait SvdDcWorkImpl: Sized {
27 type Elem: Scalar;
28 fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self>;
29 fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>>;
30 fn eval(self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>>;
31}
32
33macro_rules! impl_svd_dc_work_c {
34 ($s:ty, $sdd:path) => {
35 impl SvdDcWorkImpl for SvdDcWork<$s> {
36 type Elem = $s;
37
38 fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self> {
39 let m = layout.lda();
40 let n = layout.len();
41 let k = m.min(n);
42 let (u_col, vt_row) = match jobz {
43 JobSvd::All | JobSvd::None => (m, n),
44 JobSvd::Some => (k, k),
45 };
46
47 let mut s = vec_uninit(k as usize);
48 let (mut u, mut vt) = match jobz {
49 JobSvd::All => (
50 Some(vec_uninit((m * m) as usize)),
51 Some(vec_uninit((n * n) as usize)),
52 ),
53 JobSvd::Some => (
54 Some(vec_uninit((m * u_col) as usize)),
55 Some(vec_uninit((n * vt_row) as usize)),
56 ),
57 JobSvd::None => (None, None),
58 };
59 let mut iwork = vec_uninit(8 * k as usize);
60
61 let mx = n.max(m) as usize;
62 let mn = n.min(m) as usize;
63 let lrwork = match jobz {
64 JobSvd::None => 7 * mn,
65 _ => std::cmp::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn),
66 };
67 let mut rwork = vec_uninit(lrwork);
68
69 let mut info = 0;
70 let mut work_size = [Self::Elem::zero()];
71 unsafe {
72 $sdd(
73 jobz.as_ptr(),
74 &m,
75 &n,
76 std::ptr::null_mut(),
77 &m,
78 AsPtr::as_mut_ptr(&mut s),
79 AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
80 &m,
81 AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
82 &vt_row,
83 AsPtr::as_mut_ptr(&mut work_size),
84 &(-1),
85 AsPtr::as_mut_ptr(&mut rwork),
86 AsPtr::as_mut_ptr(&mut iwork),
87 &mut info,
88 );
89 }
90 info.as_lapack_result()?;
91 let lwork = work_size[0].to_usize().unwrap();
92 let work = vec_uninit(lwork);
93 Ok(SvdDcWork {
94 layout,
95 jobz,
96 iwork,
97 work,
98 rwork: Some(rwork),
99 u,
100 vt,
101 s,
102 })
103 }
104
105 fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
106 let m = self.layout.lda();
107 let n = self.layout.len();
108 let k = m.min(n);
109 let (_, vt_row) = match self.jobz {
110 JobSvd::All | JobSvd::None => (m, n),
111 JobSvd::Some => (k, k),
112 };
113 let lwork = self.work.len().to_i32().unwrap();
114
115 let mut info = 0;
116 unsafe {
117 $sdd(
118 self.jobz.as_ptr(),
119 &m,
120 &n,
121 AsPtr::as_mut_ptr(a),
122 &m,
123 AsPtr::as_mut_ptr(&mut self.s),
124 AsPtr::as_mut_ptr(
125 self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
126 ),
127 &m,
128 AsPtr::as_mut_ptr(
129 self.vt
130 .as_mut()
131 .map(|x| x.as_mut_slice())
132 .unwrap_or(&mut []),
133 ),
134 &vt_row,
135 AsPtr::as_mut_ptr(&mut self.work),
136 &lwork,
137 AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
138 AsPtr::as_mut_ptr(&mut self.iwork),
139 &mut info,
140 );
141 }
142 info.as_lapack_result()?;
143
144 let s = unsafe { self.s.slice_assume_init_ref() };
145 let u = self
146 .u
147 .as_ref()
148 .map(|v| unsafe { v.slice_assume_init_ref() });
149 let vt = self
150 .vt
151 .as_ref()
152 .map(|v| unsafe { v.slice_assume_init_ref() });
153
154 Ok(match self.layout {
155 MatrixLayout::F { .. } => SvdRef { s, u, vt },
156 MatrixLayout::C { .. } => SvdRef { s, u: vt, vt: u },
157 })
158 }
159
160 fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
161 let _ref = self.calc(a)?;
162 let s = unsafe { self.s.assume_init() };
163 let u = self.u.map(|v| unsafe { v.assume_init() });
164 let vt = self.vt.map(|v| unsafe { v.assume_init() });
165 Ok(match self.layout {
166 MatrixLayout::F { .. } => SvdOwned { s, u, vt },
167 MatrixLayout::C { .. } => SvdOwned { s, u: vt, vt: u },
168 })
169 }
170 }
171 };
172}
173impl_svd_dc_work_c!(c64, lapack_sys::zgesdd_);
174impl_svd_dc_work_c!(c32, lapack_sys::cgesdd_);
175
176macro_rules! impl_svd_dc_work_r {
177 ($s:ty, $sdd:path) => {
178 impl SvdDcWorkImpl for SvdDcWork<$s> {
179 type Elem = $s;
180
181 fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self> {
182 let m = layout.lda();
183 let n = layout.len();
184 let k = m.min(n);
185 let (u_col, vt_row) = match jobz {
186 JobSvd::All | JobSvd::None => (m, n),
187 JobSvd::Some => (k, k),
188 };
189
190 let mut s = vec_uninit(k as usize);
191 let (mut u, mut vt) = match jobz {
192 JobSvd::All => (
193 Some(vec_uninit((m * m) as usize)),
194 Some(vec_uninit((n * n) as usize)),
195 ),
196 JobSvd::Some => (
197 Some(vec_uninit((m * u_col) as usize)),
198 Some(vec_uninit((n * vt_row) as usize)),
199 ),
200 JobSvd::None => (None, None),
201 };
202 let mut iwork = vec_uninit(8 * k as usize);
203
204 let mut info = 0;
205 let mut work_size = [Self::Elem::zero()];
206 unsafe {
207 $sdd(
208 jobz.as_ptr(),
209 &m,
210 &n,
211 std::ptr::null_mut(),
212 &m,
213 AsPtr::as_mut_ptr(&mut s),
214 AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
215 &m,
216 AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
217 &vt_row,
218 AsPtr::as_mut_ptr(&mut work_size),
219 &(-1),
220 AsPtr::as_mut_ptr(&mut iwork),
221 &mut info,
222 );
223 }
224 info.as_lapack_result()?;
225 let lwork = work_size[0].to_usize().unwrap();
226 let work = vec_uninit(lwork);
227 Ok(SvdDcWork {
228 layout,
229 jobz,
230 iwork,
231 work,
232 rwork: None,
233 u,
234 vt,
235 s,
236 })
237 }
238
239 fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
240 let m = self.layout.lda();
241 let n = self.layout.len();
242 let k = m.min(n);
243 let (_, vt_row) = match self.jobz {
244 JobSvd::All | JobSvd::None => (m, n),
245 JobSvd::Some => (k, k),
246 };
247 let lwork = self.work.len().to_i32().unwrap();
248
249 let mut info = 0;
250 unsafe {
251 $sdd(
252 self.jobz.as_ptr(),
253 &m,
254 &n,
255 AsPtr::as_mut_ptr(a),
256 &m,
257 AsPtr::as_mut_ptr(&mut self.s),
258 AsPtr::as_mut_ptr(
259 self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
260 ),
261 &m,
262 AsPtr::as_mut_ptr(
263 self.vt
264 .as_mut()
265 .map(|x| x.as_mut_slice())
266 .unwrap_or(&mut []),
267 ),
268 &vt_row,
269 AsPtr::as_mut_ptr(&mut self.work),
270 &lwork,
271 AsPtr::as_mut_ptr(&mut self.iwork),
272 &mut info,
273 );
274 }
275 info.as_lapack_result()?;
276
277 let s = unsafe { self.s.slice_assume_init_ref() };
278 let u = self
279 .u
280 .as_ref()
281 .map(|v| unsafe { v.slice_assume_init_ref() });
282 let vt = self
283 .vt
284 .as_ref()
285 .map(|v| unsafe { v.slice_assume_init_ref() });
286
287 Ok(match self.layout {
288 MatrixLayout::F { .. } => SvdRef { s, u, vt },
289 MatrixLayout::C { .. } => SvdRef { s, u: vt, vt: u },
290 })
291 }
292
293 fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
294 let _ref = self.calc(a)?;
295 let s = unsafe { self.s.assume_init() };
296 let u = self.u.map(|v| unsafe { v.assume_init() });
297 let vt = self.vt.map(|v| unsafe { v.assume_init() });
298 Ok(match self.layout {
299 MatrixLayout::F { .. } => SvdOwned { s, u, vt },
300 MatrixLayout::C { .. } => SvdOwned { s, u: vt, vt: u },
301 })
302 }
303 }
304 };
305}
306impl_svd_dc_work_r!(f64, lapack_sys::dgesdd_);
307impl_svd_dc_work_r!(f32, lapack_sys::sgesdd_);