ndarray_linalg/
assert.rs

//! Assertion helpers for checking that scalars and arrays are close within a tolerance
2
3use ndarray::*;
4use std::fmt::Debug;
5
6use super::norm::*;
7use super::types::*;
8
9/// check two values are close in terms of the relative tolerance
10pub fn rclose<A: Scalar>(test: A, truth: A, rtol: A::Real) {
11    let dev = (test - truth).abs() / truth.abs();
12    if dev > rtol {
13        eprintln!("==== Assetion Failed ====");
14        eprintln!("Expected = {}", truth);
15        eprintln!("Actual   = {}", test);
16        panic!("Too large deviation in relative tolerance: {}", dev);
17    }
18}
19
20/// check two values are close in terms of the absolute tolerance
21pub fn aclose<A: Scalar>(test: A, truth: A, atol: A::Real) {
22    let dev = (test - truth).abs();
23    if dev > atol {
24        eprintln!("==== Assetion Failed ====");
25        eprintln!("Expected = {}", truth);
26        eprintln!("Actual   = {}", test);
27        panic!("Too large deviation in absolute tolerance: {}", dev);
28    }
29}
30
31/// check two arrays are close in maximum norm
32pub fn close_max<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: A::Real)
33where
34    A: Scalar + Lapack,
35    S1: Data<Elem = A>,
36    S2: Data<Elem = A>,
37    D: Dimension,
38    D::Pattern: PartialEq + Debug,
39{
40    assert_eq!(test.dim(), truth.dim());
41    let tol = (test - truth).norm_max();
42    if tol > atol {
43        eprintln!("==== Assetion Failed ====");
44        eprintln!("Expected:\n{}", truth);
45        eprintln!("Actual:\n{}", test);
46        panic!("Too large deviation in maximum norm: {} > {}", tol, atol);
47    }
48}
49
50/// check two arrays are close in L1 norm
51pub fn close_l1<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
52where
53    A: Scalar + Lapack,
54    S1: Data<Elem = A>,
55    S2: Data<Elem = A>,
56    D: Dimension,
57    D::Pattern: PartialEq + Debug,
58{
59    assert_eq!(test.dim(), truth.dim());
60    let tol = (test - truth).norm_l1() / truth.norm_l1();
61    if tol > rtol {
62        eprintln!("==== Assetion Failed ====");
63        eprintln!("Expected:\n{}", truth);
64        eprintln!("Actual:\n{}", test);
65        panic!("Too large deviation in L1-norm: {} > {}", tol, rtol);
66    }
67}
68
69/// check two arrays are close in L2 norm
70pub fn close_l2<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
71where
72    A: Scalar + Lapack,
73    S1: Data<Elem = A>,
74    S2: Data<Elem = A>,
75    D: Dimension,
76    D::Pattern: PartialEq + Debug,
77{
78    assert_eq!(test.dim(), truth.dim());
79    let tol = (test - truth).norm_l2() / truth.norm_l2();
80    if tol > rtol {
81        eprintln!("==== Assetion Failed ====");
82        eprintln!("Expected:\n{}", truth);
83        eprintln!("Actual:\n{}", test);
84        panic!("Too large deviation in L2-norm: {} > {} ", tol, rtol);
85    }
86}
87
/// Generate an exported assertion macro named `$assert` that forwards its
/// arguments to the crate-root function `$close`.
///
/// The generated macro accepts `(test, truth, tol)` and, optionally, a
/// trailing `; comment` which is printed to stderr before the check runs.
macro_rules! generate_assert {
    ($assert:ident, $close:ident) => {
        #[macro_export]
        macro_rules! $assert {
            ($test:expr, $truth:expr, $tol:expr) => {
                $crate::$close($test, $truth, $tol);
            };
            ($test:expr, $truth:expr, $tol:expr; $comment:expr) => {
                // Printed through "{}" so any `Display` expression works and
                // braces inside the comment are not read as format specifiers.
                eprintln!("{}", $comment);
                $crate::$close($test, $truth, $tol);
            };
        }
    };
} // generate_assert!

generate_assert!(assert_rclose, rclose);
generate_assert!(assert_aclose, aclose);
generate_assert!(assert_close_max, close_max);
generate_assert!(assert_close_l1, close_l1);
generate_assert!(assert_close_l2, close_l2);