1use lax::*;
4use ndarray::*;
5use num_traits::Zero;
6
7use super::convert::*;
8use super::error::*;
9use super::layout::*;
10use super::types::*;
11
12pub use lax::Diag;
13
/// Solve a triangular linear system, leaving the right-hand side untouched.
///
/// `self` is the triangular coefficient matrix and `b` the right-hand side. `b` is only
/// borrowed, and the solution is returned as a newly allocated `Array<A, D>`.
///
/// `uplo` and `diag` are forwarded to the LAPACK backend (`lax`). Per the LAPACK
/// triangular-solve convention, they select which triangle of `self` is referenced and
/// whether its diagonal is treated as unit. NOTE(review): confirm against `lax::Diag` docs.
///
/// # Errors
/// Returns an error if the layout/allocation checks or the backend solve fail.
pub trait SolveTriangular<A, S, D>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
    D: Dimension,
{
    fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
}
23
/// Solve a triangular linear system by consuming the right-hand side.
///
/// `self` is the triangular coefficient matrix. `b` is taken by value, and its storage is
/// reused to hold the solution, which is returned.
///
/// # Errors
/// Returns an error if the layout/allocation checks or the backend solve fail.
pub trait SolveTriangularInto<S, D>
where
    S: DataMut,
    D: Dimension,
{
    fn solve_triangular_into(
        &self,
        uplo: UPLO,
        diag: Diag,
        b: ArrayBase<S, D>,
    ) -> Result<ArrayBase<S, D>>;
}
37
/// Solve a triangular linear system in place, overwriting the right-hand side.
///
/// `self` is the triangular coefficient matrix. `b` is mutated to hold the solution, and the
/// same mutable reference is handed back on success so calls can be chained.
///
/// # Errors
/// Returns an error if the layout/allocation checks or the backend solve fail. In that case
/// the contents of `b` may already have been modified (e.g. re-laid out).
pub trait SolveTriangularInplace<S, D>
where
    S: DataMut,
    D: Dimension,
{
    fn solve_triangular_inplace<'a>(
        &self,
        uplo: UPLO,
        diag: Diag,
        b: &'a mut ArrayBase<S, D>,
    ) -> Result<&'a mut ArrayBase<S, D>>;
}
51
52impl<A, Si, So> SolveTriangularInto<So, Ix2> for ArrayBase<Si, Ix2>
53where
54 A: Scalar + Lapack,
55 Si: Data<Elem = A>,
56 So: DataMut<Elem = A> + DataOwned,
57{
58 fn solve_triangular_into(
59 &self,
60 uplo: UPLO,
61 diag: Diag,
62 mut b: ArrayBase<So, Ix2>,
63 ) -> Result<ArrayBase<So, Ix2>> {
64 self.solve_triangular_inplace(uplo, diag, &mut b)?;
65 Ok(b)
66 }
67}
68
69impl<A, Si, So> SolveTriangularInplace<So, Ix2> for ArrayBase<Si, Ix2>
70where
71 A: Scalar + Lapack,
72 Si: Data<Elem = A>,
73 So: DataMut<Elem = A> + DataOwned,
74{
75 fn solve_triangular_inplace<'a>(
76 &self,
77 uplo: UPLO,
78 diag: Diag,
79 b: &'a mut ArrayBase<So, Ix2>,
80 ) -> Result<&'a mut ArrayBase<So, Ix2>> {
81 let la = self.layout()?;
82 let a_ = self.as_allocated()?;
83 let lb = b.layout()?;
84 if !la.same_order(&lb) {
85 transpose_data(b)?;
86 }
87 let lb = b.layout()?;
88 A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?;
89 Ok(b)
90 }
91}
92
93impl<A, Si, So> SolveTriangular<A, So, Ix2> for ArrayBase<Si, Ix2>
94where
95 A: Scalar + Lapack,
96 Si: Data<Elem = A>,
97 So: DataMut<Elem = A> + DataOwned,
98{
99 fn solve_triangular(
100 &self,
101 uplo: UPLO,
102 diag: Diag,
103 b: &ArrayBase<So, Ix2>,
104 ) -> Result<Array2<A>> {
105 let b = replicate(b);
106 self.solve_triangular_into(uplo, diag, b)
107 }
108}
109
110impl<A, Si, So> SolveTriangularInto<So, Ix1> for ArrayBase<Si, Ix2>
111where
112 A: Scalar + Lapack,
113 Si: Data<Elem = A>,
114 So: DataMut<Elem = A> + DataOwned,
115{
116 fn solve_triangular_into(
117 &self,
118 uplo: UPLO,
119 diag: Diag,
120 b: ArrayBase<So, Ix1>,
121 ) -> Result<ArrayBase<So, Ix1>> {
122 let b = into_col(b);
123 let b = self.solve_triangular_into(uplo, diag, b)?;
124 Ok(flatten(b))
125 }
126}
127
128impl<A, Si, So> SolveTriangular<A, So, Ix1> for ArrayBase<Si, Ix2>
129where
130 A: Scalar + Lapack,
131 Si: Data<Elem = A>,
132 So: DataMut<Elem = A> + DataOwned,
133{
134 fn solve_triangular(
135 &self,
136 uplo: UPLO,
137 diag: Diag,
138 b: &ArrayBase<So, Ix1>,
139 ) -> Result<Array1<A>> {
140 let b = b.to_owned();
141 self.solve_triangular_into(uplo, diag, b)
142 }
143}
144
/// Zero out every entry that lies outside the triangle selected by `uplo`.
///
/// With `UPLO::Upper`, entries strictly below the main diagonal become zero. With
/// `UPLO::Lower`, entries strictly above it become zero. The diagonal is kept. `T` is the
/// value handed back, which is either the mutated borrow or the consumed array itself.
pub trait IntoTriangular<T> {
    fn into_triangular(self, uplo: UPLO) -> T;
}
148
149impl<'a, A, S> IntoTriangular<&'a mut ArrayBase<S, Ix2>> for &'a mut ArrayBase<S, Ix2>
150where
151 A: Zero,
152 S: DataMut<Elem = A>,
153{
154 fn into_triangular(self, uplo: UPLO) -> &'a mut ArrayBase<S, Ix2> {
155 match uplo {
156 UPLO::Upper => {
157 for ((i, j), val) in self.indexed_iter_mut() {
158 if i > j {
159 *val = A::zero();
160 }
161 }
162 }
163 UPLO::Lower => {
164 for ((i, j), val) in self.indexed_iter_mut() {
165 if i < j {
166 *val = A::zero();
167 }
168 }
169 }
170 }
171 self
172 }
173}
174
175impl<A, S> IntoTriangular<ArrayBase<S, Ix2>> for ArrayBase<S, Ix2>
176where
177 A: Zero,
178 S: DataMut<Elem = A>,
179{
180 fn into_triangular(mut self, uplo: UPLO) -> ArrayBase<S, Ix2> {
181 (&mut self).into_triangular(uplo);
182 self
183 }
184}