1use ndarray::*;
37
38use crate::diagonal::*;
39use crate::error::*;
40use crate::layout::*;
41use crate::operator::LinearOperator;
42use crate::types::*;
43use crate::UPLO;
44
45pub trait Eigh {
47 type EigVal;
48 type EigVec;
49 fn eigh(&self, uplo: UPLO) -> Result<(Self::EigVal, Self::EigVec)>;
50}
51
52pub trait EighInplace {
54 type EigVal;
55 fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)>;
56}
57
58pub trait EighInto: Sized {
60 type EigVal;
61 fn eigh_into(self, uplo: UPLO) -> Result<(Self::EigVal, Self)>;
62}
63
64impl<A, S> EighInto for ArrayBase<S, Ix2>
65where
66 A: Scalar + Lapack,
67 S: DataMut<Elem = A>,
68{
69 type EigVal = Array1<A::Real>;
70
71 fn eigh_into(mut self, uplo: UPLO) -> Result<(Self::EigVal, Self)> {
72 let (val, _) = self.eigh_inplace(uplo)?;
73 Ok((val, self))
74 }
75}
76
77impl<A, S, S2> EighInto for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
78where
79 A: Scalar + Lapack,
80 S: DataMut<Elem = A>,
81 S2: DataMut<Elem = A>,
82{
83 type EigVal = Array1<A::Real>;
84
85 fn eigh_into(mut self, uplo: UPLO) -> Result<(Self::EigVal, Self)> {
86 let (val, _) = self.eigh_inplace(uplo)?;
87 Ok((val, self))
88 }
89}
90
91impl<A, S> Eigh for ArrayBase<S, Ix2>
92where
93 A: Scalar + Lapack,
94 S: Data<Elem = A>,
95{
96 type EigVal = Array1<A::Real>;
97 type EigVec = Array2<A>;
98
99 fn eigh(&self, uplo: UPLO) -> Result<(Self::EigVal, Self::EigVec)> {
100 let a = self.to_owned();
101 a.eigh_into(uplo)
102 }
103}
104
105impl<A, S, S2> Eigh for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
106where
107 A: Scalar + Lapack,
108 S: Data<Elem = A>,
109 S2: Data<Elem = A>,
110{
111 type EigVal = Array1<A::Real>;
112 type EigVec = (Array2<A>, Array2<A>);
113
114 fn eigh(&self, uplo: UPLO) -> Result<(Self::EigVal, Self::EigVec)> {
115 let (a, b) = (self.0.to_owned(), self.1.to_owned());
116 (a, b).eigh_into(uplo)
117 }
118}
119
120impl<A> EighInplace for ArrayRef<A, Ix2>
121where
122 A: Scalar + Lapack,
123{
124 type EigVal = Array1<A::Real>;
125
126 fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> {
127 let layout = self.square_layout()?;
128 match layout {
130 MatrixLayout::C { .. } => self.swap_axes(0, 1),
131 MatrixLayout::F { .. } => {}
132 }
133 let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?;
134 Ok((ArrayBase::from(s), self))
135 }
136}
137
138impl<A, S, S2> EighInplace for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
139where
140 A: Scalar + Lapack,
141 S: DataMut<Elem = A>,
142 S2: DataMut<Elem = A>,
143{
144 type EigVal = Array1<A::Real>;
145
146 fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> {
152 assert_eq!(
153 self.0.shape(),
154 self.1.shape(),
155 "The shapes of the matrices must be identical.",
156 );
157 let layout = self.0.square_layout()?;
158 match layout {
160 MatrixLayout::C { .. } => self.0.swap_axes(0, 1),
161 MatrixLayout::F { .. } => {}
162 }
163
164 let layout = self.1.square_layout()?;
165 match layout {
166 MatrixLayout::C { .. } => self.1.swap_axes(0, 1),
167 MatrixLayout::F { .. } => {}
168 }
169
170 let s = A::eigh_generalized(
171 true,
172 self.0.square_layout()?,
173 uplo,
174 self.0.as_allocated_mut()?,
175 self.1.as_allocated_mut()?,
176 )?;
177
178 Ok((ArrayBase::from(s), self))
179 }
180}
181
182pub trait EigValsh {
184 type EigVal;
185 fn eigvalsh(&self, uplo: UPLO) -> Result<Self::EigVal>;
186}
187
188pub trait EigValshInto {
190 type EigVal;
191 fn eigvalsh_into(self, uplo: UPLO) -> Result<Self::EigVal>;
192}
193
194pub trait EigValshInplace {
196 type EigVal;
197 fn eigvalsh_inplace(&mut self, uplo: UPLO) -> Result<Self::EigVal>;
198}
199
200impl<A, S> EigValshInto for ArrayBase<S, Ix2>
201where
202 A: Scalar + Lapack,
203 S: DataMut<Elem = A>,
204{
205 type EigVal = Array1<A::Real>;
206
207 fn eigvalsh_into(mut self, uplo: UPLO) -> Result<Self::EigVal> {
208 self.eigvalsh_inplace(uplo)
209 }
210}
211
212impl<A, S> EigValsh for ArrayBase<S, Ix2>
213where
214 A: Scalar + Lapack,
215 S: Data<Elem = A>,
216{
217 type EigVal = Array1<A::Real>;
218
219 fn eigvalsh(&self, uplo: UPLO) -> Result<Self::EigVal> {
220 let a = self.to_owned();
221 a.eigvalsh_into(uplo)
222 }
223}
224
225impl<A, S> EigValshInplace for ArrayBase<S, Ix2>
226where
227 A: Scalar + Lapack,
228 S: DataMut<Elem = A>,
229{
230 type EigVal = Array1<A::Real>;
231
232 fn eigvalsh_inplace(&mut self, uplo: UPLO) -> Result<Self::EigVal> {
233 let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?;
234 Ok(ArrayBase::from(s))
235 }
236}
237
238pub trait SymmetricSqrt {
240 type Output;
241 fn ssqrt(&self, uplo: UPLO) -> Result<Self::Output>;
242}
243
244impl<A, S> SymmetricSqrt for ArrayBase<S, Ix2>
245where
246 A: Scalar + Lapack,
247 S: Data<Elem = A>,
248{
249 type Output = Array2<A>;
250
251 fn ssqrt(&self, uplo: UPLO) -> Result<Self::Output> {
252 let a = self.to_owned();
253 a.ssqrt_into(uplo)
254 }
255}
256
257pub trait SymmetricSqrtInto {
259 type Output;
260 fn ssqrt_into(self, uplo: UPLO) -> Result<Self::Output>;
261}
262
263impl<A, S> SymmetricSqrtInto for ArrayBase<S, Ix2>
264where
265 A: Scalar + Lapack,
266 S: DataMut<Elem = A> + DataOwned,
267{
268 type Output = Array2<A>;
269
270 fn ssqrt_into(self, uplo: UPLO) -> Result<Self::Output> {
271 let (e, v) = self.eigh_into(uplo)?;
272 let e_sqrt = Array::from_iter(e.iter().map(|r| Scalar::from_real(r.sqrt())));
273 let ev = e_sqrt.into_diagonal().apply2(&v.t());
274 Ok(v.apply2(&ev))
275 }
276}