1use crate::{error::*, layout::MatrixLayout, *};
7use cauchy::*;
8use num_traits::{ToPrimitive, Zero};
9
10pub struct BkWork<T: Scalar> {
11 pub layout: MatrixLayout,
12 pub work: Vec<MaybeUninit<T>>,
13 pub ipiv: Vec<MaybeUninit<i32>>,
14}
15
16pub trait BkWorkImpl: Sized {
26 type Elem: Scalar;
27 fn new(l: MatrixLayout) -> Result<Self>;
28 fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]>;
29 fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Pivot>;
30}
31
32macro_rules! impl_bk_work {
33 ($s:ty, $trf:path) => {
34 impl BkWorkImpl for BkWork<$s> {
35 type Elem = $s;
36
37 fn new(layout: MatrixLayout) -> Result<Self> {
38 let (n, _) = layout.size();
39 let ipiv = vec_uninit(n as usize);
40 let mut info = 0;
41 let mut work_size = [Self::Elem::zero()];
42 unsafe {
43 $trf(
44 UPLO::Upper.as_ptr(),
45 &n,
46 std::ptr::null_mut(),
47 &layout.lda(),
48 std::ptr::null_mut(),
49 AsPtr::as_mut_ptr(&mut work_size),
50 &(-1),
51 &mut info,
52 )
53 };
54 info.as_lapack_result()?;
55 let lwork = work_size[0].to_usize().unwrap();
56 let work = vec_uninit(lwork);
57 Ok(BkWork { layout, work, ipiv })
58 }
59
60 fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]> {
61 let (n, _) = self.layout.size();
62 let lwork = self.work.len().to_i32().unwrap();
63 if lwork == 0 {
64 return Ok(&[]);
65 }
66 let mut info = 0;
67 unsafe {
68 $trf(
69 uplo.as_ptr(),
70 &n,
71 AsPtr::as_mut_ptr(a),
72 &self.layout.lda(),
73 AsPtr::as_mut_ptr(&mut self.ipiv),
74 AsPtr::as_mut_ptr(&mut self.work),
75 &lwork,
76 &mut info,
77 )
78 };
79 info.as_lapack_result()?;
80 Ok(unsafe { self.ipiv.slice_assume_init_ref() })
81 }
82
83 fn eval(mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Pivot> {
84 let _ref = self.calc(uplo, a)?;
85 Ok(unsafe { self.ipiv.assume_init() })
86 }
87 }
88 };
89}
90impl_bk_work!(c64, lapack_sys::zhetrf_);
91impl_bk_work!(c32, lapack_sys::chetrf_);
92impl_bk_work!(f64, lapack_sys::dsytrf_);
93impl_bk_work!(f32, lapack_sys::ssytrf_);
94
95pub struct InvhWork<T: Scalar> {
96 pub layout: MatrixLayout,
97 pub work: Vec<MaybeUninit<T>>,
98}
99
100pub trait InvhWorkImpl: Sized {
110 type Elem;
111 fn new(layout: MatrixLayout) -> Result<Self>;
112 fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()>;
113}
114
115macro_rules! impl_invh_work {
116 ($s:ty, $tri:path) => {
117 impl InvhWorkImpl for InvhWork<$s> {
118 type Elem = $s;
119
120 fn new(layout: MatrixLayout) -> Result<Self> {
121 let (n, _) = layout.size();
122 let work = vec_uninit(n as usize);
123 Ok(InvhWork { layout, work })
124 }
125
126 fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
127 let (n, _) = self.layout.size();
128 let mut info = 0;
129 unsafe {
130 $tri(
131 uplo.as_ptr(),
132 &n,
133 AsPtr::as_mut_ptr(a),
134 &self.layout.lda(),
135 ipiv.as_ptr(),
136 AsPtr::as_mut_ptr(&mut self.work),
137 &mut info,
138 )
139 };
140 info.as_lapack_result()?;
141 Ok(())
142 }
143 }
144 };
145}
146impl_invh_work!(c64, lapack_sys::zhetri_);
147impl_invh_work!(c32, lapack_sys::chetri_);
148impl_invh_work!(f64, lapack_sys::dsytri_);
149impl_invh_work!(f32, lapack_sys::ssytri_);
150
151pub trait SolvehImpl: Scalar {
161 fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
162}
163
164macro_rules! impl_solveh_ {
165 ($s:ty, $trs:path) => {
166 impl SolvehImpl for $s {
167 fn solveh(
168 l: MatrixLayout,
169 uplo: UPLO,
170 a: &[Self],
171 ipiv: &Pivot,
172 b: &mut [Self],
173 ) -> Result<()> {
174 let (n, _) = l.size();
175 let mut info = 0;
176 unsafe {
177 $trs(
178 uplo.as_ptr(),
179 &n,
180 &1,
181 AsPtr::as_ptr(a),
182 &l.lda(),
183 ipiv.as_ptr(),
184 AsPtr::as_mut_ptr(b),
185 &n,
186 &mut info,
187 )
188 };
189 info.as_lapack_result()?;
190 Ok(())
191 }
192 }
193 };
194}
195
196impl_solveh_!(c64, lapack_sys::zhetrs_);
197impl_solveh_!(c32, lapack_sys::chetrs_);
198impl_solveh_!(f64, lapack_sys::dsytrs_);
199impl_solveh_!(f32, lapack_sys::ssytrs_);