lax/tridiagonal/
matrix.rs

1use crate::layout::*;
2use cauchy::*;
3use std::ops::{Index, IndexMut};
4
5/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
6///
7/// ```text
8/// [d0, u1,  0,   ...,       0,
9///  l1, d1, u2,            ...,
10///   0, l2, d2,
11///  ...           ...,  u{n-1},
12///   0,  ...,  l{n-1},  d{n-1},]
13/// ```
14#[derive(Clone, PartialEq, Eq)]
15pub struct Tridiagonal<A: Scalar> {
16    /// layout of raw matrix
17    pub l: MatrixLayout,
18    /// (n-1) sub-diagonal elements of matrix.
19    pub dl: Vec<A>,
20    /// (n) diagonal elements of matrix.
21    pub d: Vec<A>,
22    /// (n-1) super-diagonal elements of matrix.
23    pub du: Vec<A>,
24}
25
26impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
27    type Output = A;
28    #[inline]
29    fn index(&self, (row, col): (i32, i32)) -> &A {
30        let (n, _) = self.l.size();
31        assert!(
32            std::cmp::max(row, col) < n,
33            "ndarray: index {:?} is out of bounds for array of shape {}",
34            [row, col],
35            n
36        );
37        match row - col {
38            0 => &self.d[row as usize],
39            1 => &self.dl[col as usize],
40            -1 => &self.du[row as usize],
41            _ => panic!(
42                "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
43                [row, col]
44            ),
45        }
46    }
47}
48
49impl<A: Scalar> Index<[i32; 2]> for Tridiagonal<A> {
50    type Output = A;
51    #[inline]
52    fn index(&self, [row, col]: [i32; 2]) -> &A {
53        &self[(row, col)]
54    }
55}
56
57impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
58    #[inline]
59    fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A {
60        let (n, _) = self.l.size();
61        assert!(
62            std::cmp::max(row, col) < n,
63            "ndarray: index {:?} is out of bounds for array of shape {}",
64            [row, col],
65            n
66        );
67        match row - col {
68            0 => &mut self.d[row as usize],
69            1 => &mut self.dl[col as usize],
70            -1 => &mut self.du[row as usize],
71            _ => panic!(
72                "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
73                [row, col]
74            ),
75        }
76    }
77}
78
79impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
80    #[inline]
81    fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A {
82        &mut self[(row, col)]
83    }
84}