1use super::convert::*;
6use super::error::*;
7use super::layout::*;
8use cauchy::Scalar;
9use lax::*;
10use ndarray::*;
11use num_traits::One;
12
13pub use lax::{LUFactorizedTridiagonal, Tridiagonal};
14
15pub trait ExtractTridiagonal<A: Scalar> {
17 fn extract_tridiagonal(&self) -> Result<Tridiagonal<A>>;
24}
25
26impl<A, S> ExtractTridiagonal<A> for ArrayBase<S, Ix2>
27where
28 A: Scalar + Lapack,
29 S: Data<Elem = A>,
30{
31 fn extract_tridiagonal(&self) -> Result<Tridiagonal<A>> {
32 let l = self.square_layout()?;
33 let (n, _) = l.size();
34 if n < 2 {
35 return Err(LinalgError::NotStandardShape {
36 obj: "Tridiagonal",
37 rows: 1,
38 cols: 1,
39 });
40 }
41
42 let dl = self.slice(s![1..n, 0..n - 1]).diag().to_vec();
43 let d = self.diag().to_vec();
44 let du = self.slice(s![0..n - 1, 1..n]).diag().to_vec();
45 Ok(Tridiagonal { l, dl, d, du })
46 }
47}
48
49pub trait SolveTridiagonal<A: Scalar, D: Dimension> {
50 fn solve_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
54 fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
58 &self,
59 b: ArrayBase<S, D>,
60 ) -> Result<ArrayBase<S, D>>;
61 fn solve_t_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
65 fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
69 &self,
70 b: ArrayBase<S, D>,
71 ) -> Result<ArrayBase<S, D>>;
72 fn solve_h_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
76 fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
80 &self,
81 b: ArrayBase<S, D>,
82 ) -> Result<ArrayBase<S, D>>;
83}
84
85pub trait SolveTridiagonalInplace<A: Scalar, D: Dimension> {
86 fn solve_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
91 &self,
92 b: &'a mut ArrayBase<S, D>,
93 ) -> Result<&'a mut ArrayBase<S, D>>;
94 fn solve_t_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
99 &self,
100 b: &'a mut ArrayBase<S, D>,
101 ) -> Result<&'a mut ArrayBase<S, D>>;
102 fn solve_h_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
107 &self,
108 b: &'a mut ArrayBase<S, D>,
109 ) -> Result<&'a mut ArrayBase<S, D>>;
110}
111
112impl<A> SolveTridiagonal<A, Ix2> for LUFactorizedTridiagonal<A>
113where
114 A: Scalar + Lapack,
115{
116 fn solve_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix2>) -> Result<Array<A, Ix2>> {
117 let mut b = replicate(b);
118 self.solve_tridiagonal_inplace(&mut b)?;
119 Ok(b)
120 }
121 fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
122 &self,
123 mut b: ArrayBase<S, Ix2>,
124 ) -> Result<ArrayBase<S, Ix2>> {
125 self.solve_tridiagonal_inplace(&mut b)?;
126 Ok(b)
127 }
128 fn solve_t_tridiagonal<S: Data<Elem = A>>(
129 &self,
130 b: &ArrayBase<S, Ix2>,
131 ) -> Result<Array<A, Ix2>> {
132 let mut b = replicate(b);
133 self.solve_t_tridiagonal_inplace(&mut b)?;
134 Ok(b)
135 }
136 fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
137 &self,
138 mut b: ArrayBase<S, Ix2>,
139 ) -> Result<ArrayBase<S, Ix2>> {
140 self.solve_t_tridiagonal_inplace(&mut b)?;
141 Ok(b)
142 }
143 fn solve_h_tridiagonal<S: Data<Elem = A>>(
144 &self,
145 b: &ArrayBase<S, Ix2>,
146 ) -> Result<Array<A, Ix2>> {
147 let mut b = replicate(b);
148 self.solve_h_tridiagonal_inplace(&mut b)?;
149 Ok(b)
150 }
151 fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
152 &self,
153 mut b: ArrayBase<S, Ix2>,
154 ) -> Result<ArrayBase<S, Ix2>> {
155 self.solve_h_tridiagonal_inplace(&mut b)?;
156 Ok(b)
157 }
158}
159
160impl<A> SolveTridiagonal<A, Ix2> for Tridiagonal<A>
161where
162 A: Scalar + Lapack,
163{
164 fn solve_tridiagonal<Sb: Data<Elem = A>>(
165 &self,
166 b: &ArrayBase<Sb, Ix2>,
167 ) -> Result<Array<A, Ix2>> {
168 let mut b = replicate(b);
169 self.solve_tridiagonal_inplace(&mut b)?;
170 Ok(b)
171 }
172 fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
173 &self,
174 mut b: ArrayBase<Sb, Ix2>,
175 ) -> Result<ArrayBase<Sb, Ix2>> {
176 self.solve_tridiagonal_inplace(&mut b)?;
177 Ok(b)
178 }
179 fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
180 &self,
181 b: &ArrayBase<Sb, Ix2>,
182 ) -> Result<Array<A, Ix2>> {
183 let mut b = replicate(b);
184 self.solve_t_tridiagonal_inplace(&mut b)?;
185 Ok(b)
186 }
187 fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
188 &self,
189 mut b: ArrayBase<Sb, Ix2>,
190 ) -> Result<ArrayBase<Sb, Ix2>> {
191 self.solve_t_tridiagonal_inplace(&mut b)?;
192 Ok(b)
193 }
194 fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
195 &self,
196 b: &ArrayBase<Sb, Ix2>,
197 ) -> Result<Array<A, Ix2>> {
198 let mut b = replicate(b);
199 self.solve_h_tridiagonal_inplace(&mut b)?;
200 Ok(b)
201 }
202 fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
203 &self,
204 mut b: ArrayBase<Sb, Ix2>,
205 ) -> Result<ArrayBase<Sb, Ix2>> {
206 self.solve_h_tridiagonal_inplace(&mut b)?;
207 Ok(b)
208 }
209}
210
211impl<A, S> SolveTridiagonal<A, Ix2> for ArrayBase<S, Ix2>
212where
213 A: Scalar + Lapack,
214 S: Data<Elem = A>,
215{
216 fn solve_tridiagonal<Sb: Data<Elem = A>>(
217 &self,
218 b: &ArrayBase<Sb, Ix2>,
219 ) -> Result<Array<A, Ix2>> {
220 let mut b = replicate(b);
221 self.solve_tridiagonal_inplace(&mut b)?;
222 Ok(b)
223 }
224 fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
225 &self,
226 mut b: ArrayBase<Sb, Ix2>,
227 ) -> Result<ArrayBase<Sb, Ix2>> {
228 self.solve_tridiagonal_inplace(&mut b)?;
229 Ok(b)
230 }
231 fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
232 &self,
233 b: &ArrayBase<Sb, Ix2>,
234 ) -> Result<Array<A, Ix2>> {
235 let mut b = replicate(b);
236 self.solve_t_tridiagonal_inplace(&mut b)?;
237 Ok(b)
238 }
239 fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
240 &self,
241 mut b: ArrayBase<Sb, Ix2>,
242 ) -> Result<ArrayBase<Sb, Ix2>> {
243 self.solve_t_tridiagonal_inplace(&mut b)?;
244 Ok(b)
245 }
246 fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
247 &self,
248 b: &ArrayBase<Sb, Ix2>,
249 ) -> Result<Array<A, Ix2>> {
250 let mut b = replicate(b);
251 self.solve_h_tridiagonal_inplace(&mut b)?;
252 Ok(b)
253 }
254 fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
255 &self,
256 mut b: ArrayBase<Sb, Ix2>,
257 ) -> Result<ArrayBase<Sb, Ix2>> {
258 self.solve_h_tridiagonal_inplace(&mut b)?;
259 Ok(b)
260 }
261}
262
263impl<A> SolveTridiagonalInplace<A, Ix2> for LUFactorizedTridiagonal<A>
264where
265 A: Scalar + Lapack,
266{
267 fn solve_tridiagonal_inplace<'a, Sb>(
268 &self,
269 rhs: &'a mut ArrayBase<Sb, Ix2>,
270 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
271 where
272 Sb: DataMut<Elem = A>,
273 {
274 A::solve_tridiagonal(
275 self,
276 rhs.layout()?,
277 Transpose::No,
278 rhs.as_slice_mut().unwrap(),
279 )?;
280 Ok(rhs)
281 }
282 fn solve_t_tridiagonal_inplace<'a, Sb>(
283 &self,
284 rhs: &'a mut ArrayBase<Sb, Ix2>,
285 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
286 where
287 Sb: DataMut<Elem = A>,
288 {
289 A::solve_tridiagonal(
290 self,
291 rhs.layout()?,
292 Transpose::Transpose,
293 rhs.as_slice_mut().unwrap(),
294 )?;
295 Ok(rhs)
296 }
297 fn solve_h_tridiagonal_inplace<'a, Sb>(
298 &self,
299 rhs: &'a mut ArrayBase<Sb, Ix2>,
300 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
301 where
302 Sb: DataMut<Elem = A>,
303 {
304 A::solve_tridiagonal(
305 self,
306 rhs.layout()?,
307 Transpose::Hermite,
308 rhs.as_slice_mut().unwrap(),
309 )?;
310 Ok(rhs)
311 }
312}
313
314impl<A> SolveTridiagonalInplace<A, Ix2> for Tridiagonal<A>
315where
316 A: Scalar + Lapack,
317{
318 fn solve_tridiagonal_inplace<'a, Sb>(
319 &self,
320 rhs: &'a mut ArrayBase<Sb, Ix2>,
321 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
322 where
323 Sb: DataMut<Elem = A>,
324 {
325 let f = self.factorize_tridiagonal()?;
326 f.solve_tridiagonal_inplace(rhs)
327 }
328 fn solve_t_tridiagonal_inplace<'a, Sb>(
329 &self,
330 rhs: &'a mut ArrayBase<Sb, Ix2>,
331 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
332 where
333 Sb: DataMut<Elem = A>,
334 {
335 let f = self.factorize_tridiagonal()?;
336 f.solve_t_tridiagonal_inplace(rhs)
337 }
338 fn solve_h_tridiagonal_inplace<'a, Sb>(
339 &self,
340 rhs: &'a mut ArrayBase<Sb, Ix2>,
341 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
342 where
343 Sb: DataMut<Elem = A>,
344 {
345 let f = self.factorize_tridiagonal()?;
346 f.solve_h_tridiagonal_inplace(rhs)
347 }
348}
349
350impl<A, S> SolveTridiagonalInplace<A, Ix2> for ArrayBase<S, Ix2>
351where
352 A: Scalar + Lapack,
353 S: Data<Elem = A>,
354{
355 fn solve_tridiagonal_inplace<'a, Sb>(
356 &self,
357 rhs: &'a mut ArrayBase<Sb, Ix2>,
358 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
359 where
360 Sb: DataMut<Elem = A>,
361 {
362 let f = self.factorize_tridiagonal()?;
363 f.solve_tridiagonal_inplace(rhs)
364 }
365 fn solve_t_tridiagonal_inplace<'a, Sb>(
366 &self,
367 rhs: &'a mut ArrayBase<Sb, Ix2>,
368 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
369 where
370 Sb: DataMut<Elem = A>,
371 {
372 let f = self.factorize_tridiagonal()?;
373 f.solve_t_tridiagonal_inplace(rhs)
374 }
375 fn solve_h_tridiagonal_inplace<'a, Sb>(
376 &self,
377 rhs: &'a mut ArrayBase<Sb, Ix2>,
378 ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
379 where
380 Sb: DataMut<Elem = A>,
381 {
382 let f = self.factorize_tridiagonal()?;
383 f.solve_h_tridiagonal_inplace(rhs)
384 }
385}
386
387impl<A> SolveTridiagonal<A, Ix1> for LUFactorizedTridiagonal<A>
388where
389 A: Scalar + Lapack,
390{
391 fn solve_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array<A, Ix1>> {
392 let b = b.to_owned();
393 self.solve_tridiagonal_into(b)
394 }
395 fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
396 &self,
397 b: ArrayBase<S, Ix1>,
398 ) -> Result<ArrayBase<S, Ix1>> {
399 let b = into_col(b);
400 let b = self.solve_tridiagonal_into(b)?;
401 Ok(flatten(b))
402 }
403 fn solve_t_tridiagonal<S: Data<Elem = A>>(
404 &self,
405 b: &ArrayBase<S, Ix1>,
406 ) -> Result<Array<A, Ix1>> {
407 let b = b.to_owned();
408 self.solve_t_tridiagonal_into(b)
409 }
410 fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
411 &self,
412 b: ArrayBase<S, Ix1>,
413 ) -> Result<ArrayBase<S, Ix1>> {
414 let b = into_col(b);
415 let b = self.solve_t_tridiagonal_into(b)?;
416 Ok(flatten(b))
417 }
418 fn solve_h_tridiagonal<S: Data<Elem = A>>(
419 &self,
420 b: &ArrayBase<S, Ix1>,
421 ) -> Result<Array<A, Ix1>> {
422 let b = b.to_owned();
423 self.solve_h_tridiagonal_into(b)
424 }
425 fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
426 &self,
427 b: ArrayBase<S, Ix1>,
428 ) -> Result<ArrayBase<S, Ix1>> {
429 let b = into_col(b);
430 let b = self.solve_h_tridiagonal_into(b)?;
431 Ok(flatten(b))
432 }
433}
434
435impl<A> SolveTridiagonal<A, Ix1> for Tridiagonal<A>
436where
437 A: Scalar + Lapack,
438{
439 fn solve_tridiagonal<Sb: Data<Elem = A>>(
440 &self,
441 b: &ArrayBase<Sb, Ix1>,
442 ) -> Result<Array<A, Ix1>> {
443 let b = b.to_owned();
444 self.solve_tridiagonal_into(b)
445 }
446 fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
447 &self,
448 b: ArrayBase<Sb, Ix1>,
449 ) -> Result<ArrayBase<Sb, Ix1>> {
450 let b = into_col(b);
451 let f = self.factorize_tridiagonal()?;
452 let b = f.solve_tridiagonal_into(b)?;
453 Ok(flatten(b))
454 }
455 fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
456 &self,
457 b: &ArrayBase<Sb, Ix1>,
458 ) -> Result<Array<A, Ix1>> {
459 let b = b.to_owned();
460 self.solve_t_tridiagonal_into(b)
461 }
462 fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
463 &self,
464 b: ArrayBase<Sb, Ix1>,
465 ) -> Result<ArrayBase<Sb, Ix1>> {
466 let b = into_col(b);
467 let f = self.factorize_tridiagonal()?;
468 let b = f.solve_t_tridiagonal_into(b)?;
469 Ok(flatten(b))
470 }
471 fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
472 &self,
473 b: &ArrayBase<Sb, Ix1>,
474 ) -> Result<Array<A, Ix1>> {
475 let b = b.to_owned();
476 self.solve_h_tridiagonal_into(b)
477 }
478 fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
479 &self,
480 b: ArrayBase<Sb, Ix1>,
481 ) -> Result<ArrayBase<Sb, Ix1>> {
482 let b = into_col(b);
483 let f = self.factorize_tridiagonal()?;
484 let b = f.solve_h_tridiagonal_into(b)?;
485 Ok(flatten(b))
486 }
487}
488
489impl<A, S> SolveTridiagonal<A, Ix1> for ArrayBase<S, Ix2>
490where
491 A: Scalar + Lapack,
492 S: Data<Elem = A>,
493{
494 fn solve_tridiagonal<Sb: Data<Elem = A>>(
495 &self,
496 b: &ArrayBase<Sb, Ix1>,
497 ) -> Result<Array<A, Ix1>> {
498 let b = b.to_owned();
499 self.solve_tridiagonal_into(b)
500 }
501 fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
502 &self,
503 b: ArrayBase<Sb, Ix1>,
504 ) -> Result<ArrayBase<Sb, Ix1>> {
505 let b = into_col(b);
506 let f = self.factorize_tridiagonal()?;
507 let b = f.solve_tridiagonal_into(b)?;
508 Ok(flatten(b))
509 }
510 fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
511 &self,
512 b: &ArrayBase<Sb, Ix1>,
513 ) -> Result<Array<A, Ix1>> {
514 let b = b.to_owned();
515 self.solve_t_tridiagonal_into(b)
516 }
517 fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
518 &self,
519 b: ArrayBase<Sb, Ix1>,
520 ) -> Result<ArrayBase<Sb, Ix1>> {
521 let b = into_col(b);
522 let f = self.factorize_tridiagonal()?;
523 let b = f.solve_t_tridiagonal_into(b)?;
524 Ok(flatten(b))
525 }
526 fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
527 &self,
528 b: &ArrayBase<Sb, Ix1>,
529 ) -> Result<Array<A, Ix1>> {
530 let b = b.to_owned();
531 self.solve_h_tridiagonal_into(b)
532 }
533 fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
534 &self,
535 b: ArrayBase<Sb, Ix1>,
536 ) -> Result<ArrayBase<Sb, Ix1>> {
537 let b = into_col(b);
538 let f = self.factorize_tridiagonal()?;
539 let b = f.solve_h_tridiagonal_into(b)?;
540 Ok(flatten(b))
541 }
542}
543
544pub trait FactorizeTridiagonal<A: Scalar> {
546 fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>>;
549}
550
551pub trait FactorizeTridiagonalInto<A: Scalar> {
553 fn factorize_tridiagonal_into(self) -> Result<LUFactorizedTridiagonal<A>>;
556}
557
558impl<A> FactorizeTridiagonalInto<A> for Tridiagonal<A>
559where
560 A: Scalar + Lapack,
561{
562 fn factorize_tridiagonal_into(self) -> Result<LUFactorizedTridiagonal<A>> {
563 Ok(A::lu_tridiagonal(self)?)
564 }
565}
566
567impl<A> FactorizeTridiagonal<A> for Tridiagonal<A>
568where
569 A: Scalar + Lapack,
570{
571 fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
572 let a = self.clone();
573 Ok(A::lu_tridiagonal(a)?)
574 }
575}
576
577impl<A, S> FactorizeTridiagonal<A> for ArrayBase<S, Ix2>
578where
579 A: Scalar + Lapack,
580 S: Data<Elem = A>,
581{
582 fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
583 let a = self.extract_tridiagonal()?;
584 Ok(A::lu_tridiagonal(a)?)
585 }
586}
587
588fn rec_rel<A: Scalar>(tridiag: &Tridiagonal<A>) -> Vec<A> {
600 let n = tridiag.d.len();
601 let mut f = Vec::with_capacity(n + 1);
602 f.push(One::one());
603 f.push(tridiag.d[0]);
604 for i in 1..n {
605 f.push(tridiag.d[i] * f[i] - tridiag.dl[i - 1] * tridiag.du[i - 1] * f[i - 1]);
606 }
607 f
608}
609
610pub trait DeterminantTridiagonal<A: Scalar> {
612 fn det_tridiagonal(&self) -> Result<A>;
617}
618
619impl<A> DeterminantTridiagonal<A> for Tridiagonal<A>
620where
621 A: Scalar,
622{
623 fn det_tridiagonal(&self) -> Result<A> {
624 let n = self.d.len();
625 Ok(rec_rel(self)[n])
626 }
627}
628
629impl<A, S> DeterminantTridiagonal<A> for ArrayBase<S, Ix2>
630where
631 A: Scalar + Lapack,
632 S: Data<Elem = A>,
633{
634 fn det_tridiagonal(&self) -> Result<A> {
635 let tridiag = self.extract_tridiagonal()?;
636 let n = tridiag.d.len();
637 Ok(rec_rel(&tridiag)[n])
638 }
639}
640
641pub trait ReciprocalConditionNumTridiagonal<A: Scalar> {
643 fn rcond_tridiagonal(&self) -> Result<A::Real>;
653}
654
655pub trait ReciprocalConditionNumTridiagonalInto<A: Scalar> {
657 fn rcond_tridiagonal_into(self) -> Result<A::Real>;
667}
668
669impl<A> ReciprocalConditionNumTridiagonal<A> for LUFactorizedTridiagonal<A>
670where
671 A: Scalar + Lapack,
672{
673 fn rcond_tridiagonal(&self) -> Result<A::Real> {
674 Ok(A::rcond_tridiagonal(self)?)
675 }
676}
677
678impl<A> ReciprocalConditionNumTridiagonalInto<A> for LUFactorizedTridiagonal<A>
679where
680 A: Scalar + Lapack,
681{
682 fn rcond_tridiagonal_into(self) -> Result<A::Real> {
683 self.rcond_tridiagonal()
684 }
685}
686
687impl<A, S> ReciprocalConditionNumTridiagonal<A> for ArrayBase<S, Ix2>
688where
689 A: Scalar + Lapack,
690 S: Data<Elem = A>,
691{
692 fn rcond_tridiagonal(&self) -> Result<A::Real> {
693 self.factorize_tridiagonal()?.rcond_tridiagonal_into()
694 }
695}