1use ndarray::*;
37
38use crate::diagonal::*;
39use crate::error::*;
40use crate::layout::*;
41use crate::operator::LinearOperator;
42use crate::types::*;
43use crate::UPLO;
44
45pub trait Eigh {
47 type EigVal;
48 type EigVec;
49 fn eigh(&self, uplo: UPLO) -> Result<(Self::EigVal, Self::EigVec)>;
50}
51
52pub trait EighInplace {
54 type EigVal;
55 fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)>;
56}
57
58pub trait EighInto: Sized {
60 type EigVal;
61 fn eigh_into(self, uplo: UPLO) -> Result<(Self::EigVal, Self)>;
62}
63
64impl<A, S> EighInto for ArrayBase<S, Ix2>
65where
66 A: Scalar + Lapack,
67 S: DataMut<Elem = A>,
68{
69 type EigVal = Array1<A::Real>;
70
71 fn eigh_into(mut self, uplo: UPLO) -> Result<(Self::EigVal, Self)> {
72 let (val, _) = self.eigh_inplace(uplo)?;
73 Ok((val, self))
74 }
75}
76
77impl<A, S, S2> EighInto for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
78where
79 A: Scalar + Lapack,
80 S: DataMut<Elem = A>,
81 S2: DataMut<Elem = A>,
82{
83 type EigVal = Array1<A::Real>;
84
85 fn eigh_into(mut self, uplo: UPLO) -> Result<(Self::EigVal, Self)> {
86 let (val, _) = self.eigh_inplace(uplo)?;
87 Ok((val, self))
88 }
89}
90
91impl<A, S> Eigh for ArrayBase<S, Ix2>
92where
93 A: Scalar + Lapack,
94 S: Data<Elem = A>,
95{
96 type EigVal = Array1<A::Real>;
97 type EigVec = Array2<A>;
98
99 fn eigh(&self, uplo: UPLO) -> Result<(Self::EigVal, Self::EigVec)> {
100 let a = self.to_owned();
101 a.eigh_into(uplo)
102 }
103}
104
105impl<A, S, S2> Eigh for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
106where
107 A: Scalar + Lapack,
108 S: Data<Elem = A>,
109 S2: Data<Elem = A>,
110{
111 type EigVal = Array1<A::Real>;
112 type EigVec = (Array2<A>, Array2<A>);
113
114 fn eigh(&self, uplo: UPLO) -> Result<(Self::EigVal, Self::EigVec)> {
115 let (a, b) = (self.0.to_owned(), self.1.to_owned());
116 (a, b).eigh_into(uplo)
117 }
118}
119
120impl<A, S> EighInplace for ArrayBase<S, Ix2>
121where
122 A: Scalar + Lapack,
123 S: DataMut<Elem = A>,
124{
125 type EigVal = Array1<A::Real>;
126
127 fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> {
128 let layout = self.square_layout()?;
129 match layout {
131 MatrixLayout::C { .. } => self.swap_axes(0, 1),
132 MatrixLayout::F { .. } => {}
133 }
134 let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?;
135 Ok((ArrayBase::from(s), self))
136 }
137}
138
139impl<A, S, S2> EighInplace for (ArrayBase<S, Ix2>, ArrayBase<S2, Ix2>)
140where
141 A: Scalar + Lapack,
142 S: DataMut<Elem = A>,
143 S2: DataMut<Elem = A>,
144{
145 type EigVal = Array1<A::Real>;
146
147 fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> {
153 assert_eq!(
154 self.0.shape(),
155 self.1.shape(),
156 "The shapes of the matrices must be identical.",
157 );
158 let layout = self.0.square_layout()?;
159 match layout {
161 MatrixLayout::C { .. } => self.0.swap_axes(0, 1),
162 MatrixLayout::F { .. } => {}
163 }
164
165 let layout = self.1.square_layout()?;
166 match layout {
167 MatrixLayout::C { .. } => self.1.swap_axes(0, 1),
168 MatrixLayout::F { .. } => {}
169 }
170
171 let s = A::eigh_generalized(
172 true,
173 self.0.square_layout()?,
174 uplo,
175 self.0.as_allocated_mut()?,
176 self.1.as_allocated_mut()?,
177 )?;
178
179 Ok((ArrayBase::from(s), self))
180 }
181}
182
183pub trait EigValsh {
185 type EigVal;
186 fn eigvalsh(&self, uplo: UPLO) -> Result<Self::EigVal>;
187}
188
189pub trait EigValshInto {
191 type EigVal;
192 fn eigvalsh_into(self, uplo: UPLO) -> Result<Self::EigVal>;
193}
194
195pub trait EigValshInplace {
197 type EigVal;
198 fn eigvalsh_inplace(&mut self, uplo: UPLO) -> Result<Self::EigVal>;
199}
200
201impl<A, S> EigValshInto for ArrayBase<S, Ix2>
202where
203 A: Scalar + Lapack,
204 S: DataMut<Elem = A>,
205{
206 type EigVal = Array1<A::Real>;
207
208 fn eigvalsh_into(mut self, uplo: UPLO) -> Result<Self::EigVal> {
209 self.eigvalsh_inplace(uplo)
210 }
211}
212
213impl<A, S> EigValsh for ArrayBase<S, Ix2>
214where
215 A: Scalar + Lapack,
216 S: Data<Elem = A>,
217{
218 type EigVal = Array1<A::Real>;
219
220 fn eigvalsh(&self, uplo: UPLO) -> Result<Self::EigVal> {
221 let a = self.to_owned();
222 a.eigvalsh_into(uplo)
223 }
224}
225
226impl<A, S> EigValshInplace for ArrayBase<S, Ix2>
227where
228 A: Scalar + Lapack,
229 S: DataMut<Elem = A>,
230{
231 type EigVal = Array1<A::Real>;
232
233 fn eigvalsh_inplace(&mut self, uplo: UPLO) -> Result<Self::EigVal> {
234 let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?;
235 Ok(ArrayBase::from(s))
236 }
237}
238
239pub trait SymmetricSqrt {
241 type Output;
242 fn ssqrt(&self, uplo: UPLO) -> Result<Self::Output>;
243}
244
245impl<A, S> SymmetricSqrt for ArrayBase<S, Ix2>
246where
247 A: Scalar + Lapack,
248 S: Data<Elem = A>,
249{
250 type Output = Array2<A>;
251
252 fn ssqrt(&self, uplo: UPLO) -> Result<Self::Output> {
253 let a = self.to_owned();
254 a.ssqrt_into(uplo)
255 }
256}
257
258pub trait SymmetricSqrtInto {
260 type Output;
261 fn ssqrt_into(self, uplo: UPLO) -> Result<Self::Output>;
262}
263
264impl<A, S> SymmetricSqrtInto for ArrayBase<S, Ix2>
265where
266 A: Scalar + Lapack,
267 S: DataMut<Elem = A> + DataOwned,
268{
269 type Output = Array2<A>;
270
271 fn ssqrt_into(self, uplo: UPLO) -> Result<Self::Output> {
272 let (e, v) = self.eigh_into(uplo)?;
273 let e_sqrt = Array::from_iter(e.iter().map(|r| Scalar::from_real(r.sqrt())));
274 let ev = e_sqrt.into_diagonal().apply2(&v.t());
275 Ok(v.apply2(&ev))
276 }
277}