ndarray_linalg/
triangular.rs

1//! Methods for triangular matrices
2
3use lax::*;
4use ndarray::*;
5use num_traits::Zero;
6
7use super::convert::*;
8use super::error::*;
9use super::layout::*;
10use super::types::*;
11
12pub use lax::Diag;
13
/// Solve a linear system whose coefficient matrix `self` is triangular,
/// returning the solution as a newly allocated array.
///
/// `uplo` tells the solver which triangle (upper or lower) to use, so this is
/// not limited to upper triangular matrices. `diag` is forwarded to the
/// backend unchanged (presumably unit vs. non-unit diagonal; see `lax::Diag`).
/// The right-hand side `b` is taken by shared reference.
pub trait SolveTriangular<A, S, D>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Solve for the right-hand side `b` and return the result as an owned
    /// `Array<A, D>`, or an error if the solve could not be carried out.
    fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
}
23
/// Solve a linear system whose coefficient matrix `self` is triangular,
/// consuming the right-hand side and handing back an array of the same type.
///
/// `uplo` tells the solver which triangle (upper or lower) to use, so this is
/// not limited to upper triangular matrices. `diag` is forwarded to the
/// backend unchanged (presumably unit vs. non-unit diagonal; see `lax::Diag`).
pub trait SolveTriangularInto<S, D>
where
    S: DataMut,
    D: Dimension,
{
    /// Take ownership of `b`, solve the system, and return an array with the
    /// same storage type and dimensionality, or an error if the solve fails.
    fn solve_triangular_into(
        &self,
        uplo: UPLO,
        diag: Diag,
        b: ArrayBase<S, D>,
    ) -> Result<ArrayBase<S, D>>;
}
37
/// Solve a linear system whose coefficient matrix `self` is triangular,
/// operating directly on a mutably borrowed right-hand side.
///
/// `uplo` tells the solver which triangle (upper or lower) to use, so this is
/// not limited to upper triangular matrices. `diag` is forwarded to the
/// backend unchanged (presumably unit vs. non-unit diagonal; see `lax::Diag`).
pub trait SolveTriangularInplace<S, D>
where
    S: DataMut,
    D: Dimension,
{
    /// Solve the system using `b` as the working buffer and return the same
    /// mutable reference on success, so calls can be chained.
    fn solve_triangular_inplace<'a>(
        &self,
        uplo: UPLO,
        diag: Diag,
        b: &'a mut ArrayBase<S, D>,
    ) -> Result<&'a mut ArrayBase<S, D>>;
}
51
52impl<A, Si, So> SolveTriangularInto<So, Ix2> for ArrayBase<Si, Ix2>
53where
54    A: Scalar + Lapack,
55    Si: Data<Elem = A>,
56    So: DataMut<Elem = A> + DataOwned,
57{
58    fn solve_triangular_into(
59        &self,
60        uplo: UPLO,
61        diag: Diag,
62        mut b: ArrayBase<So, Ix2>,
63    ) -> Result<ArrayBase<So, Ix2>> {
64        self.solve_triangular_inplace(uplo, diag, &mut b)?;
65        Ok(b)
66    }
67}
68
69impl<A, Si, So> SolveTriangularInplace<So, Ix2> for ArrayBase<Si, Ix2>
70where
71    A: Scalar + Lapack,
72    Si: Data<Elem = A>,
73    So: DataMut<Elem = A> + DataOwned,
74{
75    fn solve_triangular_inplace<'a>(
76        &self,
77        uplo: UPLO,
78        diag: Diag,
79        b: &'a mut ArrayBase<So, Ix2>,
80    ) -> Result<&'a mut ArrayBase<So, Ix2>> {
81        let la = self.layout()?;
82        let a_ = self.as_allocated()?;
83        let lb = b.layout()?;
84        if !la.same_order(&lb) {
85            transpose_data(b)?;
86        }
87        let lb = b.layout()?;
88        A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?;
89        Ok(b)
90    }
91}
92
93impl<A, Si, So> SolveTriangular<A, So, Ix2> for ArrayBase<Si, Ix2>
94where
95    A: Scalar + Lapack,
96    Si: Data<Elem = A>,
97    So: DataMut<Elem = A> + DataOwned,
98{
99    fn solve_triangular(
100        &self,
101        uplo: UPLO,
102        diag: Diag,
103        b: &ArrayBase<So, Ix2>,
104    ) -> Result<Array2<A>> {
105        let b = replicate(b);
106        self.solve_triangular_into(uplo, diag, b)
107    }
108}
109
110impl<A, Si, So> SolveTriangularInto<So, Ix1> for ArrayBase<Si, Ix2>
111where
112    A: Scalar + Lapack,
113    Si: Data<Elem = A>,
114    So: DataMut<Elem = A> + DataOwned,
115{
116    fn solve_triangular_into(
117        &self,
118        uplo: UPLO,
119        diag: Diag,
120        b: ArrayBase<So, Ix1>,
121    ) -> Result<ArrayBase<So, Ix1>> {
122        let b = into_col(b);
123        let b = self.solve_triangular_into(uplo, diag, b)?;
124        Ok(flatten(b))
125    }
126}
127
128impl<A, Si, So> SolveTriangular<A, So, Ix1> for ArrayBase<Si, Ix2>
129where
130    A: Scalar + Lapack,
131    Si: Data<Elem = A>,
132    So: DataMut<Elem = A> + DataOwned,
133{
134    fn solve_triangular(
135        &self,
136        uplo: UPLO,
137        diag: Diag,
138        b: &ArrayBase<So, Ix1>,
139    ) -> Result<Array1<A>> {
140        let b = b.to_owned();
141        self.solve_triangular_into(uplo, diag, b)
142    }
143}
144
/// Conversion of a matrix into triangular form.
///
/// `T` is the type produced by the conversion; the implementations in this
/// file return the same kind of value they were called on.
pub trait IntoTriangular<T> {
    /// Convert `self` into a triangular matrix, with `uplo` choosing whether
    /// the upper or the lower triangle is the one that is kept.
    fn into_triangular(self, uplo: UPLO) -> T;
}
148
149impl<'a, A, S> IntoTriangular<&'a mut ArrayBase<S, Ix2>> for &'a mut ArrayBase<S, Ix2>
150where
151    A: Zero,
152    S: DataMut<Elem = A>,
153{
154    fn into_triangular(self, uplo: UPLO) -> &'a mut ArrayBase<S, Ix2> {
155        match uplo {
156            UPLO::Upper => {
157                for ((i, j), val) in self.indexed_iter_mut() {
158                    if i > j {
159                        *val = A::zero();
160                    }
161                }
162            }
163            UPLO::Lower => {
164                for ((i, j), val) in self.indexed_iter_mut() {
165                    if i < j {
166                        *val = A::zero();
167                    }
168                }
169            }
170        }
171        self
172    }
173}
174
175impl<A, S> IntoTriangular<ArrayBase<S, Ix2>> for ArrayBase<S, Ix2>
176where
177    A: Zero,
178    S: DataMut<Elem = A>,
179{
180    fn into_triangular(mut self, uplo: UPLO) -> ArrayBase<S, Ix2> {
181        (&mut self).into_triangular(uplo);
182        self
183    }
184}