use super::{error::*, layout::*, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

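/// Working memory for singular value decomposition (LAPACK `*gesvd`).
///
/// All buffers, including the LAPACK `work`/`rwork` arrays sized by a
/// workspace query, are allocated once in [`SvdWorkImpl::new`] and reused by
/// each call to [`SvdWorkImpl::calc`].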
pub struct SvdWork<T: Scalar> {
    pub ju: JobSvd,
    pub jvt: JobSvd,
    pub layout: MatrixLayout,
    pub s: Vec<MaybeUninit<T::Real>>,
    pub u: Option<Vec<MaybeUninit<T>>>,
    pub vt: Option<Vec<MaybeUninit<T>>>,
    pub work: Vec<MaybeUninit<T>>,
    pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
}

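/// SVD result borrowed from the buffers of [`SvdWork`].
///
/// `u` and `vt` are `None` when the corresponding singular vectors were not
/// requested.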
#[derive(Debug, Clone)]
pub struct SvdRef<'work, T: Scalar> {
    pub s: &'work [T::Real],
    pub u: Option<&'work [T]>,
    pub vt: Option<&'work [T]>,
}

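/// SVD result owning its buffers, produced by [`SvdWorkImpl::eval`].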
#[derive(Debug, Clone)]
pub struct SvdOwned<T: Scalar> {
    pub s: Vec<T::Real>,
    pub u: Option<Vec<T>>,
    pub vt: Option<Vec<T>>,
}

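/// Common interface implemented by [`SvdWork`] for `f32`, `f64`, `c32`, and `c64`.
///
/// [`new`](SvdWorkImpl::new) allocates the working memory for a given layout,
/// [`calc`](SvdWorkImpl::calc) overwrites `a` and returns results borrowed from
/// the workspace, and [`eval`](SvdWorkImpl::eval) returns an owned [`SvdOwned`].
///
/// A minimal usage sketch (marked `ignore`, not compiled as a doctest; the
/// matrix values and the surrounding `?` context are illustrative):
///
/// ```ignore
/// // 3x2 column-major matrix stored as a flat Vec<f64>.
/// let layout = MatrixLayout::F { col: 2, lda: 3 };
/// let mut a: Vec<f64> = vec![1., 2., 3., 4., 5., 6.];
/// // Request both U and V^T.
/// let work = SvdWork::<f64>::new(layout, true, true)?;
/// let SvdOwned { s, u, vt } = work.eval(&mut a)?;
/// assert_eq!(s.len(), 2); // min(3, 2) singular values
/// ```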
pub trait SvdWorkImpl: Sized {
    type Elem: Scalar;
    fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self>;
    fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>>;
    fn eval(self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>>;
}

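// Complex-element implementations (`zgesvd`/`cgesvd`). These routines need an
// additional real `rwork` buffer of length `5 * min(m, n)`.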
macro_rules! impl_svd_work_c {
    ($s:ty, $svd:path) => {
        impl SvdWorkImpl for SvdWork<$s> {
            type Elem = $s;

            fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self> {
                // A row-major (C layout) matrix is passed to LAPACK as its
                // transpose, so the jobs for U and V^T are swapped.
                let ju = match layout {
                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
                };
                let jvt = match layout {
                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
                };

                let m = layout.lda();
                let mut u = match ju {
                    JobSvd::All => Some(vec_uninit((m * m) as usize)),
                    JobSvd::None => None,
                    _ => unimplemented!("SVD with partial vector output is not supported yet"),
                };

                let n = layout.len();
                let mut vt = match jvt {
                    JobSvd::All => Some(vec_uninit((n * n) as usize)),
                    JobSvd::None => None,
                    _ => unimplemented!("SVD with partial vector output is not supported yet"),
                };

                let k = std::cmp::min(m, n);
                let mut s = vec_uninit(k as usize);
                let mut rwork = vec_uninit(5 * k as usize);

                // Workspace query: calling with lwork = -1 only writes the
                // optimal work size into `work_size[0]`; the matrix itself is
                // not needed here, so a null pointer is passed for it.
                let mut info = 0;
                let mut work_size = [Self::Elem::zero()];
                unsafe {
                    $svd(
                        ju.as_ptr(),
                        jvt.as_ptr(),
                        &m,
                        &n,
                        std::ptr::null_mut(),
                        &m,
                        AsPtr::as_mut_ptr(&mut s),
                        AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
                        &m,
                        AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
                        &n,
                        AsPtr::as_mut_ptr(&mut work_size),
                        &(-1),
                        AsPtr::as_mut_ptr(&mut rwork),
                        &mut info,
                    );
                }
                info.as_lapack_result()?;
                let lwork = work_size[0].to_usize().unwrap();
                let work = vec_uninit(lwork);
                Ok(SvdWork {
                    layout,
                    ju,
                    jvt,
                    s,
                    u,
                    vt,
                    work,
                    rwork: Some(rwork),
                })
            }

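            // Compute the SVD in place: `a` is overwritten by LAPACK and the
            // results are returned as slices borrowed from `self`. A row-major
            // (C layout) input was handed to LAPACK as its transpose, so the
            // roles of U and V^T are swapped back when building the result.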
            fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
                let m = self.layout.lda();
                let n = self.layout.len();
                let lwork = self.work.len().to_i32().unwrap();

                let mut info = 0;
                unsafe {
                    $svd(
                        self.ju.as_ptr(),
                        self.jvt.as_ptr(),
                        &m,
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &m,
                        AsPtr::as_mut_ptr(&mut self.s),
                        AsPtr::as_mut_ptr(
                            self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
                        ),
                        &m,
                        AsPtr::as_mut_ptr(
                            self.vt
                                .as_mut()
                                .map(|x| x.as_mut_slice())
                                .unwrap_or(&mut []),
                        ),
                        &n,
                        AsPtr::as_mut_ptr(&mut self.work),
                        &lwork,
                        AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                let s = unsafe { self.s.slice_assume_init_ref() };
                let u = self
                    .u
                    .as_ref()
                    .map(|v| unsafe { v.slice_assume_init_ref() });
                let vt = self
                    .vt
                    .as_ref()
                    .map(|v| unsafe { v.slice_assume_init_ref() });

                match self.layout {
                    MatrixLayout::F { .. } => Ok(SvdRef { s, u, vt }),
                    MatrixLayout::C { .. } => Ok(SvdRef { s, u: vt, vt: u }),
                }
            }

            fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
                let _ref = self.calc(a)?;
                let s = unsafe { self.s.assume_init() };
                let u = self.u.map(|v| unsafe { v.assume_init() });
                let vt = self.vt.map(|v| unsafe { v.assume_init() });
                match self.layout {
                    MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }),
                    MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }),
                }
            }
        }
    };
}
impl_svd_work_c!(c64, lapack_sys::zgesvd_);
impl_svd_work_c!(c32, lapack_sys::cgesvd_);

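// Real-element implementations (`dgesvd`/`sgesvd`). Identical to the complex
// case above except that no `rwork` buffer is needed.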
macro_rules! impl_svd_work_r {
    ($s:ty, $svd:path) => {
        impl SvdWorkImpl for SvdWork<$s> {
            type Elem = $s;

            fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self> {
                let ju = match layout {
                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
                };
                let jvt = match layout {
                    MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
                    MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
                };

                let m = layout.lda();
                let mut u = match ju {
                    JobSvd::All => Some(vec_uninit((m * m) as usize)),
                    JobSvd::None => None,
                    _ => unimplemented!("SVD with partial vector output is not supported yet"),
                };

                let n = layout.len();
                let mut vt = match jvt {
                    JobSvd::All => Some(vec_uninit((n * n) as usize)),
                    JobSvd::None => None,
                    _ => unimplemented!("SVD with partial vector output is not supported yet"),
                };

                let k = std::cmp::min(m, n);
                let mut s = vec_uninit(k as usize);

                // Workspace query (lwork = -1), as in the complex implementation.
                let mut info = 0;
                let mut work_size = [Self::Elem::zero()];
                unsafe {
                    $svd(
                        ju.as_ptr(),
                        jvt.as_ptr(),
                        &m,
                        &n,
                        std::ptr::null_mut(),
                        &m,
                        AsPtr::as_mut_ptr(&mut s),
                        AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
                        &m,
                        AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
                        &n,
                        AsPtr::as_mut_ptr(&mut work_size),
                        &(-1),
                        &mut info,
                    );
                }
                info.as_lapack_result()?;
                let lwork = work_size[0].to_usize().unwrap();
                let work = vec_uninit(lwork);
                Ok(SvdWork {
                    layout,
                    ju,
                    jvt,
                    s,
                    u,
                    vt,
                    work,
                    rwork: None,
                })
            }

            fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
                let m = self.layout.lda();
                let n = self.layout.len();
                let lwork = self.work.len().to_i32().unwrap();

                let mut info = 0;
                unsafe {
                    $svd(
                        self.ju.as_ptr(),
                        self.jvt.as_ptr(),
                        &m,
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &m,
                        AsPtr::as_mut_ptr(&mut self.s),
                        AsPtr::as_mut_ptr(
                            self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
                        ),
                        &m,
                        AsPtr::as_mut_ptr(
                            self.vt
                                .as_mut()
                                .map(|x| x.as_mut_slice())
                                .unwrap_or(&mut []),
                        ),
                        &n,
                        AsPtr::as_mut_ptr(&mut self.work),
                        &lwork,
                        &mut info,
                    );
                }
                info.as_lapack_result()?;

                let s = unsafe { self.s.slice_assume_init_ref() };
                let u = self
                    .u
                    .as_ref()
                    .map(|v| unsafe { v.slice_assume_init_ref() });
                let vt = self
                    .vt
                    .as_ref()
                    .map(|v| unsafe { v.slice_assume_init_ref() });

                match self.layout {
                    MatrixLayout::F { .. } => Ok(SvdRef { s, u, vt }),
                    MatrixLayout::C { .. } => Ok(SvdRef { s, u: vt, vt: u }),
                }
            }

            fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
                let _ref = self.calc(a)?;
                let s = unsafe { self.s.assume_init() };
                let u = self.u.map(|v| unsafe { v.assume_init() });
                let vt = self.vt.map(|v| unsafe { v.assume_init() });
                match self.layout {
                    MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }),
                    MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }),
                }
            }
        }
    };
}
impl_svd_work_r!(f64, lapack_sys::dgesvd_);
impl_svd_work_r!(f32, lapack_sys::sgesvd_);