1use ndarray::*;
4use std::iter::Sum;
5
6use super::error::*;
7use super::types::*;
8
9pub trait Trace {
10 type Output;
11 fn trace(&self) -> Result<Self::Output>;
12}
13
14impl<A, S> Trace for ArrayBase<S, Ix2>
15where
16 A: Scalar + Sum,
17 S: Data<Elem = A>,
18{
19 type Output = A;
20
21 fn trace(&self) -> Result<Self::Output> {
22 let n = match self.is_square() {
23 true => Ok(self.nrows()),
24 false => Err(LinalgError::NotSquare {
25 rows: self.nrows() as i32,
26 cols: self.ncols() as i32,
27 }),
28 }?;
29 Ok((0..n as usize).map(|i| self[(i, i)]).sum())
30 }
31}