1use ndarray::*;
4use std::fmt::Debug;
5
6use super::norm::*;
7use super::types::*;
8
9pub fn rclose<A: Scalar>(test: A, truth: A, rtol: A::Real) {
11 let dev = (test - truth).abs() / truth.abs();
12 if dev > rtol {
13 eprintln!("==== Assetion Failed ====");
14 eprintln!("Expected = {}", truth);
15 eprintln!("Actual = {}", test);
16 panic!("Too large deviation in relative tolerance: {}", dev);
17 }
18}
19
20pub fn aclose<A: Scalar>(test: A, truth: A, atol: A::Real) {
22 let dev = (test - truth).abs();
23 if dev > atol {
24 eprintln!("==== Assetion Failed ====");
25 eprintln!("Expected = {}", truth);
26 eprintln!("Actual = {}", test);
27 panic!("Too large deviation in absolute tolerance: {}", dev);
28 }
29}
30
31pub fn close_max<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: A::Real)
33where
34 A: Scalar + Lapack,
35 S1: Data<Elem = A>,
36 S2: Data<Elem = A>,
37 D: Dimension,
38 D::Pattern: PartialEq + Debug,
39{
40 assert_eq!(test.dim(), truth.dim());
41 let tol = (test - truth).norm_max();
42 if tol > atol {
43 eprintln!("==== Assetion Failed ====");
44 eprintln!("Expected:\n{}", truth);
45 eprintln!("Actual:\n{}", test);
46 panic!("Too large deviation in maximum norm: {} > {}", tol, atol);
47 }
48}
49
50pub fn close_l1<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
52where
53 A: Scalar + Lapack,
54 S1: Data<Elem = A>,
55 S2: Data<Elem = A>,
56 D: Dimension,
57 D::Pattern: PartialEq + Debug,
58{
59 assert_eq!(test.dim(), truth.dim());
60 let tol = (test - truth).norm_l1() / truth.norm_l1();
61 if tol > rtol {
62 eprintln!("==== Assetion Failed ====");
63 eprintln!("Expected:\n{}", truth);
64 eprintln!("Actual:\n{}", test);
65 panic!("Too large deviation in L1-norm: {} > {}", tol, rtol);
66 }
67}
68
69pub fn close_l2<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
71where
72 A: Scalar + Lapack,
73 S1: Data<Elem = A>,
74 S2: Data<Elem = A>,
75 D: Dimension,
76 D::Pattern: PartialEq + Debug,
77{
78 assert_eq!(test.dim(), truth.dim());
79 let tol = (test - truth).norm_l2() / truth.norm_l2();
80 if tol > rtol {
81 eprintln!("==== Assetion Failed ====");
82 eprintln!("Expected:\n{}", truth);
83 eprintln!("Actual:\n{}", test);
84 panic!("Too large deviation in L2-norm: {} > {} ", tol, rtol);
85 }
86}
87
// Defines an exported assertion macro `$name!` that forwards its arguments to the function
// `$func` of this crate (resolved through `$crate`, so it also works from other crates).
//
// Each generated macro has two forms:
//   $name!(actual, expected, tolerance);
//   $name!(actual, expected, tolerance; note);
// The second form passes `note` to `eprintln!` (so it must be a format-string literal)
// before running the check.
macro_rules! generate_assert {
    ($name:ident, $func:path) => {
        #[macro_export]
        macro_rules! $name {
            ($actual:expr, $expected:expr, $tolerance:expr) => {
                $crate::$func($actual, $expected, $tolerance);
            };
            ($actual:expr, $expected:expr, $tolerance:expr; $note:expr) => {
                eprintln!($note);
                $crate::$func($actual, $expected, $tolerance);
            };
        }
    };
}

// Scalar assertions.
generate_assert!(assert_rclose, rclose);
generate_assert!(assert_aclose, aclose);
// Array assertions (maximum norm, relative L1 norm, relative L2 norm).
generate_assert!(assert_close_max, close_max);
generate_assert!(assert_close_l1, close_l1);
generate_assert!(assert_close_l2, close_l2);