1use super::convert::*;
6use super::error::*;
7use super::layout::*;
8use cauchy::Scalar;
9use lax::*;
10use ndarray::*;
11use num_traits::One;
12
13pub use lax::{LUFactorizedTridiagonal, Tridiagonal};
14
15pub trait ExtractTridiagonal<A: Scalar> {
17 fn extract_tridiagonal(&self) -> Result<Tridiagonal<A>>;
24}
25
26impl<A> ExtractTridiagonal<A> for ArrayRef<A, Ix2>
27where
28 A: Scalar + Lapack,
29{
30 fn extract_tridiagonal(&self) -> Result<Tridiagonal<A>> {
31 let l = self.square_layout()?;
32 let (n, _) = l.size();
33 if n < 2 {
34 return Err(LinalgError::NotStandardShape {
35 obj: "Tridiagonal",
36 rows: 1,
37 cols: 1,
38 });
39 }
40
41 let dl = self.slice(s![1..n, 0..n - 1]).diag().to_vec();
42 let d = self.diag().to_vec();
43 let du = self.slice(s![0..n - 1, 1..n]).diag().to_vec();
44 Ok(Tridiagonal { l, dl, d, du })
45 }
46}
47
48pub trait SolveTridiagonal<A: Scalar, D: Dimension> {
49 fn solve_tridiagonal(&self, b: &ArrayRef<A, D>) -> Result<Array<A, D>>;
53 fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
57 &self,
58 b: ArrayBase<S, D>,
59 ) -> Result<ArrayBase<S, D>>;
60 fn solve_t_tridiagonal(&self, b: &ArrayRef<A, D>) -> Result<Array<A, D>>;
64 fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
68 &self,
69 b: ArrayBase<S, D>,
70 ) -> Result<ArrayBase<S, D>>;
71 fn solve_h_tridiagonal(&self, b: &ArrayRef<A, D>) -> Result<Array<A, D>>;
75 fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
79 &self,
80 b: ArrayBase<S, D>,
81 ) -> Result<ArrayBase<S, D>>;
82}
83
84pub trait SolveTridiagonalInplace<A: Scalar, D: Dimension> {
85 fn solve_tridiagonal_inplace<'a>(
90 &self,
91 b: &'a mut ArrayRef<A, D>,
92 ) -> Result<&'a mut ArrayRef<A, D>>;
93 fn solve_t_tridiagonal_inplace<'a>(
98 &self,
99 b: &'a mut ArrayRef<A, D>,
100 ) -> Result<&'a mut ArrayRef<A, D>>;
101 fn solve_h_tridiagonal_inplace<'a>(
106 &self,
107 b: &'a mut ArrayRef<A, D>,
108 ) -> Result<&'a mut ArrayRef<A, D>>;
109}
110
111impl<A> SolveTridiagonal<A, Ix2> for LUFactorizedTridiagonal<A>
112where
113 A: Scalar + Lapack,
114{
115 fn solve_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
116 let mut b = replicate(b);
117 self.solve_tridiagonal_inplace(&mut b)?;
118 Ok(b)
119 }
120 fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
121 &self,
122 mut b: ArrayBase<S, Ix2>,
123 ) -> Result<ArrayBase<S, Ix2>> {
124 self.solve_tridiagonal_inplace(&mut b)?;
125 Ok(b)
126 }
127 fn solve_t_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
128 let mut b = replicate(b);
129 self.solve_t_tridiagonal_inplace(&mut b)?;
130 Ok(b)
131 }
132 fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
133 &self,
134 mut b: ArrayBase<S, Ix2>,
135 ) -> Result<ArrayBase<S, Ix2>> {
136 self.solve_t_tridiagonal_inplace(&mut b)?;
137 Ok(b)
138 }
139 fn solve_h_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
140 let mut b = replicate(b);
141 self.solve_h_tridiagonal_inplace(&mut b)?;
142 Ok(b)
143 }
144 fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
145 &self,
146 mut b: ArrayBase<S, Ix2>,
147 ) -> Result<ArrayBase<S, Ix2>> {
148 self.solve_h_tridiagonal_inplace(&mut b)?;
149 Ok(b)
150 }
151}
152
153impl<A> SolveTridiagonal<A, Ix2> for Tridiagonal<A>
154where
155 A: Scalar + Lapack,
156{
157 fn solve_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
158 let mut b = replicate(b);
159 self.solve_tridiagonal_inplace(&mut b)?;
160 Ok(b)
161 }
162 fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
163 &self,
164 mut b: ArrayBase<Sb, Ix2>,
165 ) -> Result<ArrayBase<Sb, Ix2>> {
166 self.solve_tridiagonal_inplace(&mut b)?;
167 Ok(b)
168 }
169 fn solve_t_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
170 let mut b = replicate(b);
171 self.solve_t_tridiagonal_inplace(&mut b)?;
172 Ok(b)
173 }
174 fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
175 &self,
176 mut b: ArrayBase<Sb, Ix2>,
177 ) -> Result<ArrayBase<Sb, Ix2>> {
178 self.solve_t_tridiagonal_inplace(&mut b)?;
179 Ok(b)
180 }
181 fn solve_h_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
182 let mut b = replicate(b);
183 self.solve_h_tridiagonal_inplace(&mut b)?;
184 Ok(b)
185 }
186 fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
187 &self,
188 mut b: ArrayBase<Sb, Ix2>,
189 ) -> Result<ArrayBase<Sb, Ix2>> {
190 self.solve_h_tridiagonal_inplace(&mut b)?;
191 Ok(b)
192 }
193}
194
195impl<A> SolveTridiagonal<A, Ix2> for ArrayRef<A, Ix2>
196where
197 A: Scalar + Lapack,
198{
199 fn solve_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
200 let mut b = replicate(b);
201 self.solve_tridiagonal_inplace(&mut b)?;
202 Ok(b)
203 }
204 fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
205 &self,
206 mut b: ArrayBase<Sb, Ix2>,
207 ) -> Result<ArrayBase<Sb, Ix2>> {
208 self.solve_tridiagonal_inplace(&mut b)?;
209 Ok(b)
210 }
211 fn solve_t_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
212 let mut b = replicate(b);
213 self.solve_t_tridiagonal_inplace(&mut b)?;
214 Ok(b)
215 }
216 fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
217 &self,
218 mut b: ArrayBase<Sb, Ix2>,
219 ) -> Result<ArrayBase<Sb, Ix2>> {
220 self.solve_t_tridiagonal_inplace(&mut b)?;
221 Ok(b)
222 }
223 fn solve_h_tridiagonal(&self, b: &ArrayRef<A, Ix2>) -> Result<Array<A, Ix2>> {
224 let mut b = replicate(b);
225 self.solve_h_tridiagonal_inplace(&mut b)?;
226 Ok(b)
227 }
228 fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
229 &self,
230 mut b: ArrayBase<Sb, Ix2>,
231 ) -> Result<ArrayBase<Sb, Ix2>> {
232 self.solve_h_tridiagonal_inplace(&mut b)?;
233 Ok(b)
234 }
235}
236
237impl<A> SolveTridiagonalInplace<A, Ix2> for LUFactorizedTridiagonal<A>
238where
239 A: Scalar + Lapack,
240{
241 fn solve_tridiagonal_inplace<'a>(
242 &self,
243 rhs: &'a mut ArrayRef<A, Ix2>,
244 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
245 A::solve_tridiagonal(
246 self,
247 rhs.layout()?,
248 Transpose::No,
249 rhs.as_slice_mut().unwrap(),
250 )?;
251 Ok(rhs)
252 }
253 fn solve_t_tridiagonal_inplace<'a>(
254 &self,
255 rhs: &'a mut ArrayRef<A, Ix2>,
256 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
257 A::solve_tridiagonal(
258 self,
259 rhs.layout()?,
260 Transpose::Transpose,
261 rhs.as_slice_mut().unwrap(),
262 )?;
263 Ok(rhs)
264 }
265 fn solve_h_tridiagonal_inplace<'a>(
266 &self,
267 rhs: &'a mut ArrayRef<A, Ix2>,
268 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
269 A::solve_tridiagonal(
270 self,
271 rhs.layout()?,
272 Transpose::Hermite,
273 rhs.as_slice_mut().unwrap(),
274 )?;
275 Ok(rhs)
276 }
277}
278
279impl<A> SolveTridiagonalInplace<A, Ix2> for Tridiagonal<A>
280where
281 A: Scalar + Lapack,
282{
283 fn solve_tridiagonal_inplace<'a>(
284 &self,
285 rhs: &'a mut ArrayRef<A, Ix2>,
286 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
287 let f = self.factorize_tridiagonal()?;
288 f.solve_tridiagonal_inplace(rhs)
289 }
290 fn solve_t_tridiagonal_inplace<'a>(
291 &self,
292 rhs: &'a mut ArrayRef<A, Ix2>,
293 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
294 let f = self.factorize_tridiagonal()?;
295 f.solve_t_tridiagonal_inplace(rhs)
296 }
297 fn solve_h_tridiagonal_inplace<'a>(
298 &self,
299 rhs: &'a mut ArrayRef<A, Ix2>,
300 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
301 let f = self.factorize_tridiagonal()?;
302 f.solve_h_tridiagonal_inplace(rhs)
303 }
304}
305
306impl<A> SolveTridiagonalInplace<A, Ix2> for ArrayRef<A, Ix2>
307where
308 A: Scalar + Lapack,
309{
310 fn solve_tridiagonal_inplace<'a>(
311 &self,
312 rhs: &'a mut ArrayRef<A, Ix2>,
313 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
314 let f = self.factorize_tridiagonal()?;
315 f.solve_tridiagonal_inplace(rhs)
316 }
317 fn solve_t_tridiagonal_inplace<'a>(
318 &self,
319 rhs: &'a mut ArrayRef<A, Ix2>,
320 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
321 let f = self.factorize_tridiagonal()?;
322 f.solve_t_tridiagonal_inplace(rhs)
323 }
324 fn solve_h_tridiagonal_inplace<'a>(
325 &self,
326 rhs: &'a mut ArrayRef<A, Ix2>,
327 ) -> Result<&'a mut ArrayRef<A, Ix2>> {
328 let f = self.factorize_tridiagonal()?;
329 f.solve_h_tridiagonal_inplace(rhs)
330 }
331}
332
333impl<A> SolveTridiagonal<A, Ix1> for LUFactorizedTridiagonal<A>
334where
335 A: Scalar + Lapack,
336{
337 fn solve_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
338 let b = b.to_owned();
339 self.solve_tridiagonal_into(b)
340 }
341 fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
342 &self,
343 b: ArrayBase<S, Ix1>,
344 ) -> Result<ArrayBase<S, Ix1>> {
345 let b = into_col(b);
346 let b = self.solve_tridiagonal_into(b)?;
347 Ok(flatten(b))
348 }
349 fn solve_t_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
350 let b = b.to_owned();
351 self.solve_t_tridiagonal_into(b)
352 }
353 fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
354 &self,
355 b: ArrayBase<S, Ix1>,
356 ) -> Result<ArrayBase<S, Ix1>> {
357 let b = into_col(b);
358 let b = self.solve_t_tridiagonal_into(b)?;
359 Ok(flatten(b))
360 }
361 fn solve_h_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
362 let b = b.to_owned();
363 self.solve_h_tridiagonal_into(b)
364 }
365 fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
366 &self,
367 b: ArrayBase<S, Ix1>,
368 ) -> Result<ArrayBase<S, Ix1>> {
369 let b = into_col(b);
370 let b = self.solve_h_tridiagonal_into(b)?;
371 Ok(flatten(b))
372 }
373}
374
375impl<A> SolveTridiagonal<A, Ix1> for Tridiagonal<A>
376where
377 A: Scalar + Lapack,
378{
379 fn solve_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
380 let b = b.to_owned();
381 self.solve_tridiagonal_into(b)
382 }
383 fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
384 &self,
385 b: ArrayBase<Sb, Ix1>,
386 ) -> Result<ArrayBase<Sb, Ix1>> {
387 let b = into_col(b);
388 let f = self.factorize_tridiagonal()?;
389 let b = f.solve_tridiagonal_into(b)?;
390 Ok(flatten(b))
391 }
392 fn solve_t_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
393 let b = b.to_owned();
394 self.solve_t_tridiagonal_into(b)
395 }
396 fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
397 &self,
398 b: ArrayBase<Sb, Ix1>,
399 ) -> Result<ArrayBase<Sb, Ix1>> {
400 let b = into_col(b);
401 let f = self.factorize_tridiagonal()?;
402 let b = f.solve_t_tridiagonal_into(b)?;
403 Ok(flatten(b))
404 }
405 fn solve_h_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
406 let b = b.to_owned();
407 self.solve_h_tridiagonal_into(b)
408 }
409 fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
410 &self,
411 b: ArrayBase<Sb, Ix1>,
412 ) -> Result<ArrayBase<Sb, Ix1>> {
413 let b = into_col(b);
414 let f = self.factorize_tridiagonal()?;
415 let b = f.solve_h_tridiagonal_into(b)?;
416 Ok(flatten(b))
417 }
418}
419
420impl<A> SolveTridiagonal<A, Ix1> for ArrayRef<A, Ix2>
421where
422 A: Scalar + Lapack,
423{
424 fn solve_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
425 let b = b.to_owned();
426 self.solve_tridiagonal_into(b)
427 }
428 fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
429 &self,
430 b: ArrayBase<Sb, Ix1>,
431 ) -> Result<ArrayBase<Sb, Ix1>> {
432 let b = into_col(b);
433 let f = self.factorize_tridiagonal()?;
434 let b = f.solve_tridiagonal_into(b)?;
435 Ok(flatten(b))
436 }
437 fn solve_t_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
438 let b = b.to_owned();
439 self.solve_t_tridiagonal_into(b)
440 }
441 fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
442 &self,
443 b: ArrayBase<Sb, Ix1>,
444 ) -> Result<ArrayBase<Sb, Ix1>> {
445 let b = into_col(b);
446 let f = self.factorize_tridiagonal()?;
447 let b = f.solve_t_tridiagonal_into(b)?;
448 Ok(flatten(b))
449 }
450 fn solve_h_tridiagonal(&self, b: &ArrayRef<A, Ix1>) -> Result<Array<A, Ix1>> {
451 let b = b.to_owned();
452 self.solve_h_tridiagonal_into(b)
453 }
454 fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
455 &self,
456 b: ArrayBase<Sb, Ix1>,
457 ) -> Result<ArrayBase<Sb, Ix1>> {
458 let b = into_col(b);
459 let f = self.factorize_tridiagonal()?;
460 let b = f.solve_h_tridiagonal_into(b)?;
461 Ok(flatten(b))
462 }
463}
464
465pub trait FactorizeTridiagonal<A: Scalar> {
467 fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>>;
470}
471
472pub trait FactorizeTridiagonalInto<A: Scalar> {
474 fn factorize_tridiagonal_into(self) -> Result<LUFactorizedTridiagonal<A>>;
477}
478
479impl<A> FactorizeTridiagonalInto<A> for Tridiagonal<A>
480where
481 A: Scalar + Lapack,
482{
483 fn factorize_tridiagonal_into(self) -> Result<LUFactorizedTridiagonal<A>> {
484 Ok(A::lu_tridiagonal(self)?)
485 }
486}
487
488impl<A> FactorizeTridiagonal<A> for Tridiagonal<A>
489where
490 A: Scalar + Lapack,
491{
492 fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
493 let a = self.clone();
494 Ok(A::lu_tridiagonal(a)?)
495 }
496}
497
498impl<A> FactorizeTridiagonal<A> for ArrayRef<A, Ix2>
499where
500 A: Scalar + Lapack,
501{
502 fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
503 let a = self.extract_tridiagonal()?;
504 Ok(A::lu_tridiagonal(a)?)
505 }
506}
507
508fn rec_rel<A: Scalar>(tridiag: &Tridiagonal<A>) -> Vec<A> {
520 let n = tridiag.d.len();
521 let mut f = Vec::with_capacity(n + 1);
522 f.push(One::one());
523 f.push(tridiag.d[0]);
524 for i in 1..n {
525 f.push(tridiag.d[i] * f[i] - tridiag.dl[i - 1] * tridiag.du[i - 1] * f[i - 1]);
526 }
527 f
528}
529
530pub trait DeterminantTridiagonal<A: Scalar> {
532 fn det_tridiagonal(&self) -> Result<A>;
537}
538
539impl<A> DeterminantTridiagonal<A> for Tridiagonal<A>
540where
541 A: Scalar,
542{
543 fn det_tridiagonal(&self) -> Result<A> {
544 let n = self.d.len();
545 Ok(rec_rel(self)[n])
546 }
547}
548
549impl<A> DeterminantTridiagonal<A> for ArrayRef<A, Ix2>
550where
551 A: Scalar + Lapack,
552{
553 fn det_tridiagonal(&self) -> Result<A> {
554 let tridiag = self.extract_tridiagonal()?;
555 let n = tridiag.d.len();
556 Ok(rec_rel(&tridiag)[n])
557 }
558}
559
560pub trait ReciprocalConditionNumTridiagonal<A: Scalar> {
562 fn rcond_tridiagonal(&self) -> Result<A::Real>;
572}
573
574pub trait ReciprocalConditionNumTridiagonalInto<A: Scalar> {
576 fn rcond_tridiagonal_into(self) -> Result<A::Real>;
586}
587
588impl<A> ReciprocalConditionNumTridiagonal<A> for LUFactorizedTridiagonal<A>
589where
590 A: Scalar + Lapack,
591{
592 fn rcond_tridiagonal(&self) -> Result<A::Real> {
593 Ok(A::rcond_tridiagonal(self)?)
594 }
595}
596
597impl<A> ReciprocalConditionNumTridiagonalInto<A> for LUFactorizedTridiagonal<A>
598where
599 A: Scalar + Lapack,
600{
601 fn rcond_tridiagonal_into(self) -> Result<A::Real> {
602 self.rcond_tridiagonal()
603 }
604}
605
606impl<A> ReciprocalConditionNumTridiagonal<A> for ArrayRef<A, Ix2>
607where
608 A: Scalar + Lapack,
609{
610 fn rcond_tridiagonal(&self) -> Result<A::Real> {
611 self.factorize_tridiagonal()?.rcond_tridiagonal_into()
612 }
613}