lax/
qr.rs

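//! QR decomposition.
//!
//! Householder factorization via LAPACK `?geqrf` / `?gelqf`, and generation of the
//! explicit orthogonal/unitary factor via `?orgqr` / `?ungqr` (`?orglq` / `?unglq`).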
2
3use crate::{error::*, layout::MatrixLayout, *};
4use cauchy::*;
5use num_traits::{ToPrimitive, Zero};
6
7pub struct HouseholderWork<T: Scalar> {
8    pub m: i32,
9    pub n: i32,
10    pub layout: MatrixLayout,
11    pub tau: Vec<MaybeUninit<T>>,
12    pub work: Vec<MaybeUninit<T>>,
13}
14
15pub trait HouseholderWorkImpl: Sized {
16    type Elem: Scalar;
17    fn new(l: MatrixLayout) -> Result<Self>;
18    fn calc(&mut self, a: &mut [Self::Elem]) -> Result<&[Self::Elem]>;
19    fn eval(self, a: &mut [Self::Elem]) -> Result<Vec<Self::Elem>>;
20}
21
// Implements `HouseholderWorkImpl` for one scalar type.
//
// - `$s`:   element type (`f32`, `f64`, `c32`, `c64`).
// - `$qrf`: LAPACK `?geqrf` binding, used for `MatrixLayout::F`.
// - `$lqf`: LAPACK `?gelqf` binding, used for `MatrixLayout::C`.
macro_rules! impl_householder_work {
    ($s:ty, $qrf:path, $lqf:path) => {
        impl HouseholderWorkImpl for HouseholderWork<$s> {
            type Elem = $s;

            /// Allocates `tau` and asks LAPACK for the optimal workspace size.
            ///
            /// The buffer is handed to LAPACK as a column-major `m x n` matrix with
            /// `m = layout.lda()`, `n = layout.len()` and leading dimension `m`.
            /// An `F` layout is factorized with `?geqrf`, a `C` layout with `?gelqf`
            /// on that same view (presumably the transpose of the row-major matrix —
            /// confirm against `MatrixLayout`).
            ///
            /// # Errors
            /// Propagates a non-zero `info` from the workspace query.
            ///
            /// # Panics
            /// If LAPACK reports a workspace size not representable as `usize`.
            fn new(layout: MatrixLayout) -> Result<Self> {
                let m = layout.lda();
                let n = layout.len();
                let k = m.min(n);
                let mut tau = vec_uninit(k as usize);
                let mut info = 0;
                let mut work_size = [Self::Elem::zero()];
                // Workspace query (`lwork == -1`): LAPACK only stores the optimal size
                // in `work_size[0]` and does not access the matrix, hence the null pointer.
                match layout {
                    MatrixLayout::F { .. } => unsafe {
                        $qrf(
                            &m,
                            &n,
                            std::ptr::null_mut(),
                            &m,
                            AsPtr::as_mut_ptr(&mut tau),
                            AsPtr::as_mut_ptr(&mut work_size),
                            &(-1),
                            &mut info,
                        )
                    },
                    MatrixLayout::C { .. } => unsafe {
                        $lqf(
                            &m,
                            &n,
                            std::ptr::null_mut(),
                            &m,
                            AsPtr::as_mut_ptr(&mut tau),
                            AsPtr::as_mut_ptr(&mut work_size),
                            &(-1),
                            &mut info,
                        )
                    },
                }
                info.as_lapack_result()?;
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query returned a size not representable as usize");
                let work = vec_uninit(lwork);
                Ok(HouseholderWork {
                    m,
                    n,
                    layout,
                    tau,
                    work,
                })
            }

            /// Factorizes `a` in place and returns the reflector coefficients `tau`.
            ///
            /// On success `a` holds the packed factorization in LAPACK's
            /// `?geqrf`/`?gelqf` format and the returned slice has `min(m, n)` entries.
            ///
            /// # Errors
            /// Propagates a non-zero `info` reported by LAPACK.
            ///
            /// # Panics
            /// If `a` holds fewer than `m * n` elements.
            fn calc(&mut self, a: &mut [Self::Elem]) -> Result<&[Self::Elem]> {
                // LAPACK reads and writes `m * n` elements (leading dimension `m`)
                // through the raw pointer; an undersized slice would be overrun.
                let required = (self.m as usize) * (self.n as usize);
                assert!(
                    a.len() >= required,
                    "matrix buffer has {} elements, but the layout requires {}",
                    a.len(),
                    required
                );
                let lwork = self.work.len().to_i32().unwrap();
                let mut info = 0;
                // SAFETY: `a` covers the `m x n` matrix (checked above), `tau` holds
                // `min(m, n)` elements and `work` holds exactly `lwork` elements.
                match self.layout {
                    MatrixLayout::F { .. } => unsafe {
                        $qrf(
                            &self.m,
                            &self.n,
                            AsPtr::as_mut_ptr(a),
                            &self.m,
                            AsPtr::as_mut_ptr(&mut self.tau),
                            AsPtr::as_mut_ptr(&mut self.work),
                            &lwork,
                            &mut info,
                        );
                    },
                    MatrixLayout::C { .. } => unsafe {
                        $lqf(
                            &self.m,
                            &self.n,
                            AsPtr::as_mut_ptr(a),
                            &self.m,
                            AsPtr::as_mut_ptr(&mut self.tau),
                            AsPtr::as_mut_ptr(&mut self.work),
                            &lwork,
                            &mut info,
                        );
                    },
                }
                info.as_lapack_result()?;
                // SAFETY: on success (`info == 0`, checked above) LAPACK has written
                // every one of the `min(m, n)` entries of `tau`.
                Ok(unsafe { self.tau.slice_assume_init_ref() })
            }

            /// Consuming variant of `calc` that returns `tau` as an owned vector.
            fn eval(mut self, a: &mut [Self::Elem]) -> Result<Vec<Self::Elem>> {
                self.calc(a)?;
                // SAFETY: `calc` returned `Ok`, so every entry of `tau` is initialized.
                Ok(unsafe { self.tau.assume_init() })
            }
        }
    };
}
112impl_householder_work!(c64, lapack_sys::zgeqrf_, lapack_sys::zgelqf_);
113impl_householder_work!(c32, lapack_sys::cgeqrf_, lapack_sys::cgelqf_);
114impl_householder_work!(f64, lapack_sys::dgeqrf_, lapack_sys::dgelqf_);
115impl_householder_work!(f32, lapack_sys::sgeqrf_, lapack_sys::sgelqf_);
116
117pub struct QWork<T: Scalar> {
118    pub layout: MatrixLayout,
119    pub work: Vec<MaybeUninit<T>>,
120}
121
122pub trait QWorkImpl: Sized {
123    type Elem: Scalar;
124    fn new(layout: MatrixLayout) -> Result<Self>;
125    fn calc(&mut self, a: &mut [Self::Elem], tau: &[Self::Elem]) -> Result<()>;
126}
127
// Implements `QWorkImpl` for one scalar type.
//
// - `$s`:   element type (`f32`, `f64`, `c32`, `c64`).
// - `$gqr`: `?orgqr` / `?ungqr` binding, used for `MatrixLayout::F`.
// - `$glq`: `?orglq` / `?unglq` binding, used for `MatrixLayout::C`.
macro_rules! impl_q_work {
    ($s:ty, $gqr:path, $glq:path) => {
        impl QWorkImpl for QWork<$s> {
            type Elem = $s;

            /// Queries LAPACK for the optimal workspace size and allocates it.
            ///
            /// With `m = layout.lda()`, `n = layout.len()` and `k = min(m, n)`, the
            /// `F` layout generates an `m x k` factor (`?orgqr`/`?ungqr`) and the
            /// `C` layout a `k x n` factor (`?orglq`/`?unglq`), both with leading
            /// dimension `m`.
            ///
            /// # Errors
            /// Propagates a non-zero `info` from the workspace query.
            ///
            /// # Panics
            /// If LAPACK reports a workspace size not representable as `usize`.
            fn new(layout: MatrixLayout) -> Result<Self> {
                let m = layout.lda();
                let n = layout.len();
                let k = m.min(n);
                let mut info = 0;
                let mut work_size = [Self::Elem::zero()];
                // Workspace query (`lwork == -1`): only `work_size[0]` is written,
                // so null matrix and `tau` pointers are acceptable here.
                match layout {
                    MatrixLayout::F { .. } => unsafe {
                        $gqr(
                            &m,
                            &k,
                            &k,
                            std::ptr::null_mut(),
                            &m,
                            std::ptr::null_mut(),
                            AsPtr::as_mut_ptr(&mut work_size),
                            &(-1),
                            &mut info,
                        )
                    },
                    MatrixLayout::C { .. } => unsafe {
                        $glq(
                            &k,
                            &n,
                            &k,
                            std::ptr::null_mut(),
                            &m,
                            std::ptr::null_mut(),
                            AsPtr::as_mut_ptr(&mut work_size),
                            &(-1),
                            &mut info,
                        )
                    },
                }
                // A failed query leaves `work_size` meaningless; report the failure
                // instead of sizing the workspace from it.
                info.as_lapack_result()?;
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query returned a size not representable as usize");
                let work = vec_uninit(lwork);
                Ok(QWork { layout, work })
            }

            /// Overwrites `a` with the explicit orthogonal/unitary factor.
            ///
            /// `a` is expected to hold the packed reflectors and `tau` their
            /// coefficients, as produced by `HouseholderWorkImpl::calc` for the
            /// same layout.
            ///
            /// # Errors
            /// Propagates a non-zero `info` reported by LAPACK.
            ///
            /// # Panics
            /// If `tau` has fewer than `min(m, n)` entries, or `a` is shorter than
            /// the block LAPACK accesses.
            fn calc(&mut self, a: &mut [Self::Elem], tau: &[Self::Elem]) -> Result<()> {
                let m = self.layout.lda();
                let n = self.layout.len();
                let k = m.min(n);

                // LAPACK dereferences these raw pointers without knowing the slice
                // lengths, so validate them up front rather than risk an overrun.
                assert!(
                    tau.len() >= k as usize,
                    "tau has {} entries, but min(m, n) = {} are required",
                    tau.len(),
                    k
                );
                let (rows, cols) = match self.layout {
                    MatrixLayout::F { .. } => (m, k),
                    MatrixLayout::C { .. } => (k, n),
                };
                // Elements touched in a `rows x cols` column-major block with
                // leading dimension `m`; nothing is touched for an empty block.
                let required = if rows == 0 || cols == 0 {
                    0
                } else {
                    (m as usize) * (cols as usize - 1) + rows as usize
                };
                assert!(
                    a.len() >= required,
                    "matrix buffer has {} elements, but {} are required",
                    a.len(),
                    required
                );

                let lwork = self.work.len().to_i32().unwrap();
                let mut info = 0;
                // SAFETY: `a` and `tau` were bounds-checked above and `work` holds
                // exactly `lwork` elements.
                match self.layout {
                    MatrixLayout::F { .. } => unsafe {
                        $gqr(
                            &m,
                            &k,
                            &k,
                            AsPtr::as_mut_ptr(a),
                            &m,
                            AsPtr::as_ptr(&tau),
                            AsPtr::as_mut_ptr(&mut self.work),
                            &lwork,
                            &mut info,
                        )
                    },
                    MatrixLayout::C { .. } => unsafe {
                        $glq(
                            &k,
                            &n,
                            &k,
                            AsPtr::as_mut_ptr(a),
                            &m,
                            AsPtr::as_ptr(&tau),
                            AsPtr::as_mut_ptr(&mut self.work),
                            &lwork,
                            &mut info,
                        )
                    },
                }
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
}
212
213impl_q_work!(c64, lapack_sys::zungqr_, lapack_sys::zunglq_);
214impl_q_work!(c32, lapack_sys::cungqr_, lapack_sys::cunglq_);
215impl_q_work!(f64, lapack_sys::dorgqr_, lapack_sys::dorglq_);
216impl_q_work!(f32, lapack_sys::sorgqr_, lapack_sys::sorglq_);