use super::convert::*;
use super::error::*;
use super::layout::*;
use cauchy::Scalar;
use lax::*;
use ndarray::*;
use num_traits::One;
pub use lax::{LUFactorizedTridiagonal, Tridiagonal};
/// Extension trait for obtaining a [`Tridiagonal`] representation of a matrix.
pub trait ExtractTridiagonal<A: Scalar> {
/// Builds a [`Tridiagonal`] from `self`, or returns an error when that is not possible.
fn extract_tridiagonal(&self) -> Result<Tridiagonal<A>>;
}
/// Tridiagonal extraction for two-dimensional arrays.
impl<A, S> ExtractTridiagonal<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Copies the sub-, main and super-diagonal of a square matrix into a [`Tridiagonal`].
    ///
    /// Entries outside those three diagonals are not read.
    ///
    /// # Errors
    ///
    /// * Propagates whatever `square_layout` reports for `self`.
    /// * Returns [`LinalgError::NotStandardShape`] carrying the actual dimensions when the
    ///   matrix is smaller than 2x2, since there is no off-diagonal to extract.
    fn extract_tridiagonal(&self) -> Result<Tridiagonal<A>> {
        let l = self.square_layout()?;
        let (n, cols) = l.size();
        if n < 2 {
            // Report the real shape: a hard-coded 1x1 would be misleading for an empty matrix.
            return Err(LinalgError::NotStandardShape {
                obj: "Tridiagonal",
                rows: n,
                cols,
            });
        }
        // The sub-diagonal is the main diagonal of the window shifted one row down.
        let dl = self.slice(s![1..n, 0..n - 1]).diag().to_vec();
        let d = self.diag().to_vec();
        // The super-diagonal is the main diagonal of the window shifted one column right.
        let du = self.slice(s![0..n - 1, 1..n]).diag().to_vec();
        Ok(Tridiagonal { l, dl, d, du })
    }
}
/// Solves linear systems whose coefficient matrix `A` (represented by `self`) is tridiagonal.
///
/// Each operation comes in two flavours: the plain method borrows the right-hand side `b`
/// and returns a newly owned solution, while the `_into` method consumes `b` and returns
/// it holding the solution.
pub trait SolveTridiagonal<A: Scalar, D: Dimension> {
/// Solves `A * x = b`, leaving `b` untouched and returning `x` as a new array.
fn solve_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
/// Solves `A * x = b`, consuming `b` and returning it filled with `x`.
fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
&self,
b: ArrayBase<S, D>,
) -> Result<ArrayBase<S, D>>;
/// Solves the transposed system `A^T * x = b`, returning `x` as a new array.
fn solve_t_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
/// Solves the transposed system `A^T * x = b`, consuming `b` and returning it filled with `x`.
fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
&self,
b: ArrayBase<S, D>,
) -> Result<ArrayBase<S, D>>;
/// Solves the conjugate-transposed system `A^H * x = b`, returning `x` as a new array.
fn solve_h_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
/// Solves the conjugate-transposed system `A^H * x = b`, consuming `b` and returning it
/// filled with `x`.
fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
&self,
b: ArrayBase<S, D>,
) -> Result<ArrayBase<S, D>>;
}
/// In-place counterpart of the tridiagonal solvers: the right-hand side is overwritten
/// with the solution and the same mutable reference is handed back on success.
pub trait SolveTridiagonalInplace<A: Scalar, D: Dimension> {
/// Solves `A * x = b` where `A` is `self`; `b` is overwritten with `x`.
fn solve_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
&self,
b: &'a mut ArrayBase<S, D>,
) -> Result<&'a mut ArrayBase<S, D>>;
/// Solves the transposed system `A^T * x = b`; `b` is overwritten with `x`.
fn solve_t_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
&self,
b: &'a mut ArrayBase<S, D>,
) -> Result<&'a mut ArrayBase<S, D>>;
/// Solves the conjugate-transposed system `A^H * x = b`; `b` is overwritten with `x`.
fn solve_h_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
&self,
b: &'a mut ArrayBase<S, D>,
) -> Result<&'a mut ArrayBase<S, D>>;
}
/// Matrix-shaped (`Ix2`) right-hand sides solved with an existing LU factorization.
///
/// The borrowing methods work on the array produced by `replicate(b)`, the `_into`
/// methods work on `b` itself; both then defer to the matching in-place solver.
impl<A> SolveTridiagonal<A, Ix2> for LUFactorizedTridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Solves `A * x = b` on a replica of `b`.
    fn solve_tridiagonal<S>(&self, b: &ArrayBase<S, Ix2>) -> Result<Array<A, Ix2>>
    where
        S: Data<Elem = A>,
    {
        let mut x: Array<A, Ix2> = replicate(b);
        self.solve_tridiagonal_inplace(&mut x)?;
        Ok(x)
    }

    /// Solves `A * x = b`, reusing the storage of `b` for the solution.
    fn solve_tridiagonal_into<S>(&self, b: ArrayBase<S, Ix2>) -> Result<ArrayBase<S, Ix2>>
    where
        S: DataMut<Elem = A>,
    {
        let mut x = b;
        self.solve_tridiagonal_inplace(&mut x)?;
        Ok(x)
    }

    /// Solves `A^T * x = b` on a replica of `b`.
    fn solve_t_tridiagonal<S>(&self, b: &ArrayBase<S, Ix2>) -> Result<Array<A, Ix2>>
    where
        S: Data<Elem = A>,
    {
        let mut x: Array<A, Ix2> = replicate(b);
        self.solve_t_tridiagonal_inplace(&mut x)?;
        Ok(x)
    }

    /// Solves `A^T * x = b`, reusing the storage of `b` for the solution.
    fn solve_t_tridiagonal_into<S>(&self, b: ArrayBase<S, Ix2>) -> Result<ArrayBase<S, Ix2>>
    where
        S: DataMut<Elem = A>,
    {
        let mut x = b;
        self.solve_t_tridiagonal_inplace(&mut x)?;
        Ok(x)
    }

    /// Solves `A^H * x = b` on a replica of `b`.
    fn solve_h_tridiagonal<S>(&self, b: &ArrayBase<S, Ix2>) -> Result<Array<A, Ix2>>
    where
        S: Data<Elem = A>,
    {
        let mut x: Array<A, Ix2> = replicate(b);
        self.solve_h_tridiagonal_inplace(&mut x)?;
        Ok(x)
    }

    /// Solves `A^H * x = b`, reusing the storage of `b` for the solution.
    fn solve_h_tridiagonal_into<S>(&self, b: ArrayBase<S, Ix2>) -> Result<ArrayBase<S, Ix2>>
    where
        S: DataMut<Elem = A>,
    {
        let mut x = b;
        self.solve_h_tridiagonal_inplace(&mut x)?;
        Ok(x)
    }
}
/// Matrix-shaped (`Ix2`) right-hand sides solved against an unfactorized [`Tridiagonal`].
///
/// All methods forward to the in-place solvers of `Tridiagonal`; the borrowing variants
/// first obtain an owned array through `replicate(b)`.
impl<A> SolveTridiagonal<A, Ix2> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Solves `A * x = b` on a replica of `b`.
    fn solve_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix2>) -> Result<Array<A, Ix2>>
    where
        Sb: Data<Elem = A>,
    {
        let mut solution: Array<A, Ix2> = replicate(b);
        self.solve_tridiagonal_inplace(&mut solution)?;
        Ok(solution)
    }

    /// Solves `A * x = b`, writing the solution into the consumed `b`.
    fn solve_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix2>) -> Result<ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        let mut solution = b;
        self.solve_tridiagonal_inplace(&mut solution)?;
        Ok(solution)
    }

    /// Solves `A^T * x = b` on a replica of `b`.
    fn solve_t_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix2>) -> Result<Array<A, Ix2>>
    where
        Sb: Data<Elem = A>,
    {
        let mut solution: Array<A, Ix2> = replicate(b);
        self.solve_t_tridiagonal_inplace(&mut solution)?;
        Ok(solution)
    }

    /// Solves `A^T * x = b`, writing the solution into the consumed `b`.
    fn solve_t_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix2>) -> Result<ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        let mut solution = b;
        self.solve_t_tridiagonal_inplace(&mut solution)?;
        Ok(solution)
    }

    /// Solves `A^H * x = b` on a replica of `b`.
    fn solve_h_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix2>) -> Result<Array<A, Ix2>>
    where
        Sb: Data<Elem = A>,
    {
        let mut solution: Array<A, Ix2> = replicate(b);
        self.solve_h_tridiagonal_inplace(&mut solution)?;
        Ok(solution)
    }

    /// Solves `A^H * x = b`, writing the solution into the consumed `b`.
    fn solve_h_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix2>) -> Result<ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        let mut solution = b;
        self.solve_h_tridiagonal_inplace(&mut solution)?;
        Ok(solution)
    }
}
/// Matrix-shaped (`Ix2`) right-hand sides solved against a dense two-dimensional array
/// that is treated as a tridiagonal matrix.
///
/// All methods forward to the in-place solvers of the array; the borrowing variants
/// first obtain an owned array through `replicate(b)`.
impl<A, S> SolveTridiagonal<A, Ix2> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Solves `A * x = b` on a replica of `b`.
    fn solve_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix2>) -> Result<Array<A, Ix2>>
    where
        Sb: Data<Elem = A>,
    {
        let mut unknowns: Array<A, Ix2> = replicate(b);
        self.solve_tridiagonal_inplace(&mut unknowns)?;
        Ok(unknowns)
    }

    /// Solves `A * x = b`, writing the solution into the consumed `b`.
    fn solve_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix2>) -> Result<ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        let mut unknowns = b;
        self.solve_tridiagonal_inplace(&mut unknowns)?;
        Ok(unknowns)
    }

    /// Solves `A^T * x = b` on a replica of `b`.
    fn solve_t_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix2>) -> Result<Array<A, Ix2>>
    where
        Sb: Data<Elem = A>,
    {
        let mut unknowns: Array<A, Ix2> = replicate(b);
        self.solve_t_tridiagonal_inplace(&mut unknowns)?;
        Ok(unknowns)
    }

    /// Solves `A^T * x = b`, writing the solution into the consumed `b`.
    fn solve_t_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix2>) -> Result<ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        let mut unknowns = b;
        self.solve_t_tridiagonal_inplace(&mut unknowns)?;
        Ok(unknowns)
    }

    /// Solves `A^H * x = b` on a replica of `b`.
    fn solve_h_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix2>) -> Result<Array<A, Ix2>>
    where
        Sb: Data<Elem = A>,
    {
        let mut unknowns: Array<A, Ix2> = replicate(b);
        self.solve_h_tridiagonal_inplace(&mut unknowns)?;
        Ok(unknowns)
    }

    /// Solves `A^H * x = b`, writing the solution into the consumed `b`.
    fn solve_h_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix2>) -> Result<ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        let mut unknowns = b;
        self.solve_h_tridiagonal_inplace(&mut unknowns)?;
        Ok(unknowns)
    }
}
impl<A> SolveTridiagonalInplace<A, Ix2> for LUFactorizedTridiagonal<A>
where
A: Scalar + Lapack,
{
fn solve_tridiagonal_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix2>,
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
where
Sb: DataMut<Elem = A>,
{
A::solve_tridiagonal(
self,
rhs.layout()?,
Transpose::No,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
fn solve_t_tridiagonal_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix2>,
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
where
Sb: DataMut<Elem = A>,
{
A::solve_tridiagonal(
self,
rhs.layout()?,
Transpose::Transpose,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
fn solve_h_tridiagonal_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix2>,
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
where
Sb: DataMut<Elem = A>,
{
A::solve_tridiagonal(
self,
rhs.layout()?,
Transpose::Hermite,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
}
/// In-place solvers for an unfactorized [`Tridiagonal`].
///
/// Every call computes a fresh LU factorization of `self` and then delegates to the
/// factorization's own in-place solver.
impl<A> SolveTridiagonalInplace<A, Ix2> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Factorizes `self`, then solves `A * x = rhs` in place.
    fn solve_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?.solve_tridiagonal_inplace(rhs)
    }

    /// Factorizes `self`, then solves `A^T * x = rhs` in place.
    fn solve_t_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?.solve_t_tridiagonal_inplace(rhs)
    }

    /// Factorizes `self`, then solves `A^H * x = rhs` in place.
    fn solve_h_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?.solve_h_tridiagonal_inplace(rhs)
    }
}
/// In-place solvers for a dense two-dimensional array treated as a tridiagonal matrix.
///
/// Every call runs `factorize_tridiagonal` on `self` and hands the right-hand side to
/// the resulting factorization.
impl<A, S> SolveTridiagonalInplace<A, Ix2> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Factorizes `self`, then solves `A * x = rhs` in place.
    fn solve_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?.solve_tridiagonal_inplace(rhs)
    }

    /// Factorizes `self`, then solves `A^T * x = rhs` in place.
    fn solve_t_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?.solve_t_tridiagonal_inplace(rhs)
    }

    /// Factorizes `self`, then solves `A^H * x = rhs` in place.
    fn solve_h_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?.solve_h_tridiagonal_inplace(rhs)
    }
}
/// Vector-shaped (`Ix1`) right-hand sides solved with an existing LU factorization.
///
/// The vector is turned into a two-dimensional array with `into_col`, solved with the
/// `Ix2` implementation, and converted back with `flatten`.
impl<A> SolveTridiagonal<A, Ix1> for LUFactorizedTridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Solves `A * x = b` on an owned copy of `b`.
    fn solve_tridiagonal<S>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array<A, Ix1>>
    where
        S: Data<Elem = A>,
    {
        self.solve_tridiagonal_into(b.to_owned())
    }

    /// Solves `A * x = b`, consuming `b`.
    fn solve_tridiagonal_into<S>(&self, b: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>>
    where
        S: DataMut<Elem = A>,
    {
        let column = into_col(b);
        let solved = self.solve_tridiagonal_into(column)?;
        Ok(flatten(solved))
    }

    /// Solves `A^T * x = b` on an owned copy of `b`.
    fn solve_t_tridiagonal<S>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array<A, Ix1>>
    where
        S: Data<Elem = A>,
    {
        self.solve_t_tridiagonal_into(b.to_owned())
    }

    /// Solves `A^T * x = b`, consuming `b`.
    fn solve_t_tridiagonal_into<S>(&self, b: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>>
    where
        S: DataMut<Elem = A>,
    {
        let column = into_col(b);
        let solved = self.solve_t_tridiagonal_into(column)?;
        Ok(flatten(solved))
    }

    /// Solves `A^H * x = b` on an owned copy of `b`.
    fn solve_h_tridiagonal<S>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array<A, Ix1>>
    where
        S: Data<Elem = A>,
    {
        self.solve_h_tridiagonal_into(b.to_owned())
    }

    /// Solves `A^H * x = b`, consuming `b`.
    fn solve_h_tridiagonal_into<S>(&self, b: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>>
    where
        S: DataMut<Elem = A>,
    {
        let column = into_col(b);
        let solved = self.solve_h_tridiagonal_into(column)?;
        Ok(flatten(solved))
    }
}
/// Vector-shaped (`Ix1`) right-hand sides solved against an unfactorized [`Tridiagonal`].
///
/// Each `_into` method converts the vector with `into_col`, factorizes `self`, solves
/// through the factorization and converts the result back with `flatten`.
impl<A> SolveTridiagonal<A, Ix1> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Solves `A * x = b` on an owned copy of `b`.
    fn solve_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix1>) -> Result<Array<A, Ix1>>
    where
        Sb: Data<Elem = A>,
    {
        self.solve_tridiagonal_into(b.to_owned())
    }

    /// Solves `A * x = b`, consuming `b`.
    fn solve_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix1>) -> Result<ArrayBase<Sb, Ix1>>
    where
        Sb: DataMut<Elem = A>,
    {
        let column = into_col(b);
        let lu = self.factorize_tridiagonal()?;
        let solved = lu.solve_tridiagonal_into(column)?;
        Ok(flatten(solved))
    }

    /// Solves `A^T * x = b` on an owned copy of `b`.
    fn solve_t_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix1>) -> Result<Array<A, Ix1>>
    where
        Sb: Data<Elem = A>,
    {
        self.solve_t_tridiagonal_into(b.to_owned())
    }

    /// Solves `A^T * x = b`, consuming `b`.
    fn solve_t_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix1>) -> Result<ArrayBase<Sb, Ix1>>
    where
        Sb: DataMut<Elem = A>,
    {
        let column = into_col(b);
        let lu = self.factorize_tridiagonal()?;
        let solved = lu.solve_t_tridiagonal_into(column)?;
        Ok(flatten(solved))
    }

    /// Solves `A^H * x = b` on an owned copy of `b`.
    fn solve_h_tridiagonal<Sb>(&self, b: &ArrayBase<Sb, Ix1>) -> Result<Array<A, Ix1>>
    where
        Sb: Data<Elem = A>,
    {
        self.solve_h_tridiagonal_into(b.to_owned())
    }

    /// Solves `A^H * x = b`, consuming `b`.
    fn solve_h_tridiagonal_into<Sb>(&self, b: ArrayBase<Sb, Ix1>) -> Result<ArrayBase<Sb, Ix1>>
    where
        Sb: DataMut<Elem = A>,
    {
        let column = into_col(b);
        let lu = self.factorize_tridiagonal()?;
        let solved = lu.solve_h_tridiagonal_into(column)?;
        Ok(flatten(solved))
    }
}
impl<A, S> SolveTridiagonal<A, Ix1> for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn solve_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_tridiagonal_into(b)
}
fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_tridiagonal_into(b)?;
Ok(flatten(b))
}
fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_t_tridiagonal_into(b)
}
fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_t_tridiagonal_into(b)?;
Ok(flatten(b))
}
fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_h_tridiagonal_into(b)
}
fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_h_tridiagonal_into(b)?;
Ok(flatten(b))
}
}
/// Computes the LU factorization of a tridiagonal matrix without consuming it.
pub trait FactorizeTridiagonal<A: Scalar> {
/// Returns the LU factorization of `self` as an [`LUFactorizedTridiagonal`];
/// `self` is only borrowed.
fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>>;
}
/// Computes the LU factorization of a tridiagonal matrix, consuming the matrix.
pub trait FactorizeTridiagonalInto<A: Scalar> {
/// Returns the LU factorization of `self` as an [`LUFactorizedTridiagonal`];
/// `self` is taken by value.
fn factorize_tridiagonal_into(self) -> Result<LUFactorizedTridiagonal<A>>;
}
/// Consuming LU factorization of a [`Tridiagonal`].
impl<A> FactorizeTridiagonalInto<A> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Hands `self` to the backend `lu_tridiagonal` routine and returns its factorization.
    fn factorize_tridiagonal_into(self) -> Result<LUFactorizedTridiagonal<A>> {
        let factorized = A::lu_tridiagonal(self)?;
        Ok(factorized)
    }
}
/// Non-consuming LU factorization of a [`Tridiagonal`].
impl<A> FactorizeTridiagonal<A> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Factorizes a clone of `self`, because the backend routine takes its input by value.
    fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
        let factorized = A::lu_tridiagonal(self.clone())?;
        Ok(factorized)
    }
}
/// LU factorization of a dense two-dimensional array treated as a tridiagonal matrix.
impl<A, S> FactorizeTridiagonal<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Extracts the tridiagonal part of `self` and factorizes it.
    ///
    /// Fails with the error of `extract_tridiagonal` or of the backend factorization.
    fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
        let tridiag = self.extract_tridiagonal()?;
        let factorized = A::lu_tridiagonal(tridiag)?;
        Ok(factorized)
    }
}
/// Evaluates the three-term recurrence used to compute the determinant of a tridiagonal
/// matrix.
///
/// With `n = tridiag.d.len()`, the returned vector `f` has `n + 1` entries:
///
/// * `f[0] = 1`
/// * `f[1] = d[0]`
/// * `f[i + 1] = d[i] * f[i] - dl[i - 1] * du[i - 1] * f[i - 1]` for `1 <= i < n`
///
/// Callers read the determinant from `f[n]`. For an empty diagonal the result is `[1]`
/// (the empty product) instead of a panic on `d[0]`.
///
/// # Panics
///
/// Panics if `dl` or `du` holds fewer than `n - 1` elements.
fn rec_rel<A: Scalar>(tridiag: &Tridiagonal<A>) -> Vec<A> {
    let n = tridiag.d.len();
    let mut f = Vec::with_capacity(n + 1);
    f.push(One::one());
    // Guard the empty case: there is no `d[0]` to seed the recurrence with.
    let d0 = match tridiag.d.first() {
        Some(&d0) => d0,
        None => return f,
    };
    f.push(d0);
    for i in 1..n {
        f.push(tridiag.d[i] * f[i] - tridiag.dl[i - 1] * tridiag.du[i - 1] * f[i - 1]);
    }
    f
}
/// Computes the determinant of a tridiagonal matrix.
pub trait DeterminantTridiagonal<A: Scalar> {
/// Returns the determinant of `self`, or an error if it cannot be computed.
fn det_tridiagonal(&self) -> Result<A>;
}
/// Determinant of a [`Tridiagonal`] via the `rec_rel` recurrence.
impl<A> DeterminantTridiagonal<A> for Tridiagonal<A>
where
    A: Scalar,
{
    /// Returns the entry of `rec_rel(self)` at index `self.d.len()`.
    fn det_tridiagonal(&self) -> Result<A> {
        let minors = rec_rel(self);
        let order = self.d.len();
        Ok(minors[order])
    }
}
/// Determinant of a dense two-dimensional array treated as a tridiagonal matrix.
impl<A, S> DeterminantTridiagonal<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Extracts the tridiagonal part of `self` and evaluates the `rec_rel` recurrence on it.
    ///
    /// Fails with the error of `extract_tridiagonal`.
    fn det_tridiagonal(&self) -> Result<A> {
        let tridiag = self.extract_tridiagonal()?;
        let minors = rec_rel(&tridiag);
        let order = tridiag.d.len();
        Ok(minors[order])
    }
}
/// Reciprocal condition number of a tridiagonal matrix, computed from a borrowed value.
pub trait ReciprocalConditionNumTridiagonal<A: Scalar> {
/// Returns the reciprocal condition number as the real type associated with `A`.
// NOTE(review): which matrix norm the backend uses is not visible here — confirm
// against the backend documentation before relying on it.
fn rcond_tridiagonal(&self) -> Result<A::Real>;
}
/// Reciprocal condition number of a tridiagonal matrix, computed from a consumed value.
pub trait ReciprocalConditionNumTridiagonalInto<A: Scalar> {
/// Consumes `self` and returns the reciprocal condition number as the real type
/// associated with `A`.
fn rcond_tridiagonal_into(self) -> Result<A::Real>;
}
/// Reciprocal condition number computed from an existing LU factorization.
impl<A> ReciprocalConditionNumTridiagonal<A> for LUFactorizedTridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Delegates to the backend `rcond_tridiagonal` routine and converts its error type.
    fn rcond_tridiagonal(&self) -> Result<A::Real> {
        let rcond = A::rcond_tridiagonal(self)?;
        Ok(rcond)
    }
}
impl<A> ReciprocalConditionNumTridiagonalInto<A> for LUFactorizedTridiagonal<A>
where
A: Scalar + Lapack,
{
fn rcond_tridiagonal_into(self) -> Result<A::Real> {
self.rcond_tridiagonal()
}
}
/// Reciprocal condition number of a dense two-dimensional array treated as a
/// tridiagonal matrix.
impl<A, S> ReciprocalConditionNumTridiagonal<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Factorizes `self` and evaluates the reciprocal condition number on the
    /// factorization, which is consumed in the process.
    fn rcond_tridiagonal(&self) -> Result<A::Real> {
        let factorized = self.factorize_tridiagonal()?;
        factorized.rcond_tridiagonal_into()
    }
}