1use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
7pub struct HouseholderWork<T: Scalar> {
8 pub m: i32,
9 pub n: i32,
10 pub layout: MatrixLayout,
11 pub tau: Vec<MaybeUninit<T>>,
12 pub work: Vec<MaybeUninit<T>>,
13}
14
15pub trait HouseholderWorkImpl: Sized {
16 type Elem: Scalar;
17 fn new(l: MatrixLayout) -> Result<Self>;
18 fn calc(&mut self, a: &mut [Self::Elem]) -> Result<&[Self::Elem]>;
19 fn eval(self, a: &mut [Self::Elem]) -> Result<Vec<Self::Elem>>;
20}
21
/// Implements `HouseholderWorkImpl` for `HouseholderWork<$s>`.
///
/// `$qrf` is called for `MatrixLayout::F` and `$lqf` for `MatrixLayout::C`;
/// both receive the same argument list.
macro_rules! impl_householder_work {
    ($s:ty, $qrf:path, $lqf: path) => {
        impl HouseholderWorkImpl for HouseholderWork<$s> {
            type Elem = $s;

            /// Allocates `tau` and the scratch buffer for a matrix described by `layout`.
            ///
            /// The scratch length comes from a workspace query (`lwork = -1`);
            /// a non-zero `info` from that query is returned as an error.
            fn new(layout: MatrixLayout) -> Result<Self> {
                let (m, n) = (layout.lda(), layout.len());
                let mut tau = vec_uninit(m.min(n) as usize);
                let mut status = 0;
                let mut optimal = [Self::Elem::zero()];
                // `lwork == -1` asks LAPACK for the optimal scratch size only.
                let query: i32 = -1;
                // SAFETY: the matrix pointer is null because this is a size
                // query; `tau`, `optimal` and `status` are live local buffers.
                unsafe {
                    match layout {
                        MatrixLayout::F { .. } => $qrf(
                            &m,
                            &n,
                            std::ptr::null_mut(),
                            &m,
                            AsPtr::as_mut_ptr(&mut tau),
                            AsPtr::as_mut_ptr(&mut optimal),
                            &query,
                            &mut status,
                        ),
                        MatrixLayout::C { .. } => $lqf(
                            &m,
                            &n,
                            std::ptr::null_mut(),
                            &m,
                            AsPtr::as_mut_ptr(&mut tau),
                            AsPtr::as_mut_ptr(&mut optimal),
                            &query,
                            &mut status,
                        ),
                    }
                }
                status.as_lapack_result()?;
                let scratch_len = optimal[0].to_usize().unwrap();
                let work = vec_uninit(scratch_len);
                Ok(Self {
                    m,
                    n,
                    layout,
                    tau,
                    work,
                })
            }

            /// Factorizes `a` in place and returns the scalar factors stored in `self.tau`.
            ///
            /// The length of `a` is not checked here; it is presumably expected to
            /// hold at least `m * n` elements in the layout given to `new`.
            fn calc(&mut self, a: &mut [Self::Elem]) -> Result<&[Self::Elem]> {
                let mut status = 0;
                let scratch_len = self.work.len().to_i32().unwrap();
                // SAFETY: `tau` and `work` were sized by `new` for these
                // dimensions; the size of `a` is the caller's responsibility.
                unsafe {
                    match self.layout {
                        MatrixLayout::F { .. } => $qrf(
                            &self.m,
                            &self.n,
                            AsPtr::as_mut_ptr(a),
                            &self.m,
                            AsPtr::as_mut_ptr(&mut self.tau),
                            AsPtr::as_mut_ptr(&mut self.work),
                            &scratch_len,
                            &mut status,
                        ),
                        MatrixLayout::C { .. } => $lqf(
                            &self.m,
                            &self.n,
                            AsPtr::as_mut_ptr(a),
                            &self.m,
                            AsPtr::as_mut_ptr(&mut self.tau),
                            AsPtr::as_mut_ptr(&mut self.work),
                            &scratch_len,
                            &mut status,
                        ),
                    }
                }
                status.as_lapack_result()?;
                // SAFETY: the routine reported success, so `tau` is treated as written.
                let factors = unsafe { self.tau.slice_assume_init_ref() };
                Ok(factors)
            }

            /// Runs `calc` once and returns the scalar factors by value.
            fn eval(mut self, a: &mut [Self::Elem]) -> Result<Vec<Self::Elem>> {
                self.calc(a)?;
                // SAFETY: `calc` succeeded, so `tau` is treated as written.
                Ok(unsafe { self.tau.assume_init() })
            }
        }
    };
}
112impl_householder_work!(c64, lapack_sys::zgeqrf_, lapack_sys::zgelqf_);
113impl_householder_work!(c32, lapack_sys::cgeqrf_, lapack_sys::cgelqf_);
114impl_householder_work!(f64, lapack_sys::dgeqrf_, lapack_sys::dgelqf_);
115impl_householder_work!(f32, lapack_sys::sgeqrf_, lapack_sys::sgelqf_);
116
117pub struct QWork<T: Scalar> {
118 pub layout: MatrixLayout,
119 pub work: Vec<MaybeUninit<T>>,
120}
121
122pub trait QWorkImpl: Sized {
123 type Elem: Scalar;
124 fn new(layout: MatrixLayout) -> Result<Self>;
125 fn calc(&mut self, a: &mut [Self::Elem], tau: &[Self::Elem]) -> Result<()>;
126}
127
/// Implements `QWorkImpl` for `QWork<$s>`.
///
/// `$gqr` is called for `MatrixLayout::F` and `$glq` for `MatrixLayout::C`.
macro_rules! impl_q_work {
    ($s:ty, $gqr:path, $glq:path) => {
        impl QWorkImpl for QWork<$s> {
            type Elem = $s;

            /// Allocates the scratch buffer for a matrix described by `layout`.
            ///
            /// The buffer length comes from a workspace query (`lwork = -1`).
            ///
            /// # Errors
            ///
            /// Fails when the workspace query reports a non-zero `info`.
            fn new(layout: MatrixLayout) -> Result<Self> {
                let m = layout.lda();
                let n = layout.len();
                let k = m.min(n);
                let mut info = 0;
                let mut work_size = [Self::Elem::zero()];
                match layout {
                    // SAFETY: matrix and `tau` pointers are null because this is
                    // a size query; `work_size` and `info` are live locals.
                    MatrixLayout::F { .. } => unsafe {
                        $gqr(
                            &m,
                            &k,
                            &k,
                            std::ptr::null_mut(),
                            &m,
                            std::ptr::null_mut(),
                            AsPtr::as_mut_ptr(&mut work_size),
                            &(-1),
                            &mut info,
                        )
                    },
                    // SAFETY: same as the column-major branch above.
                    MatrixLayout::C { .. } => unsafe {
                        $glq(
                            &k,
                            &n,
                            &k,
                            std::ptr::null_mut(),
                            &m,
                            std::ptr::null_mut(),
                            AsPtr::as_mut_ptr(&mut work_size),
                            &(-1),
                            &mut info,
                        )
                    },
                }
                // Check the query result before trusting `work_size`: on failure
                // LAPACK may not have written it, and the buffer would be mis-sized.
                info.as_lapack_result()?;
                let lwork = work_size[0].to_usize().unwrap();
                let work = vec_uninit(lwork);
                Ok(QWork { layout, work })
            }

            /// Overwrites `a` in place using the scalar factors in `tau`.
            ///
            /// With `k = min(lda, len)` of the stored layout, the column-major
            /// branch requests an `lda x k` result and the row-major branch a
            /// `k x len` one. The lengths of `a` and `tau` are not checked here.
            ///
            /// # Errors
            ///
            /// Fails when the LAPACK routine reports a non-zero `info`.
            fn calc(&mut self, a: &mut [Self::Elem], tau: &[Self::Elem]) -> Result<()> {
                let m = self.layout.lda();
                let n = self.layout.len();
                let k = m.min(n);
                let lwork = self.work.len().to_i32().unwrap();
                let mut info = 0;
                match self.layout {
                    // SAFETY: `work` was sized by `new` for this layout; the
                    // sizes of `a` and `tau` are the caller's responsibility.
                    MatrixLayout::F { .. } => unsafe {
                        $gqr(
                            &m,
                            &k,
                            &k,
                            AsPtr::as_mut_ptr(a),
                            &m,
                            AsPtr::as_ptr(&tau),
                            AsPtr::as_mut_ptr(&mut self.work),
                            &lwork,
                            &mut info,
                        )
                    },
                    // SAFETY: same as the column-major branch above.
                    MatrixLayout::C { .. } => unsafe {
                        $glq(
                            &k,
                            &n,
                            &k,
                            AsPtr::as_mut_ptr(a),
                            &m,
                            AsPtr::as_ptr(&tau),
                            AsPtr::as_mut_ptr(&mut self.work),
                            &lwork,
                            &mut info,
                        )
                    },
                }
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
}
212
213impl_q_work!(c64, lapack_sys::zungqr_, lapack_sys::zunglq_);
214impl_q_work!(c32, lapack_sys::cungqr_, lapack_sys::cunglq_);
215impl_q_work!(f64, lapack_sys::dorgqr_, lapack_sys::dorglq_);
216impl_q_work!(f32, lapack_sys::sorgqr_, lapack_sys::sorglq_);