diff --git a/zenu-optimizer/tests/net_test.rs b/zenu-optimizer/tests/net_test.rs index c8ed1019..4fc2e7bb 100644 --- a/zenu-optimizer/tests/net_test.rs +++ b/zenu-optimizer/tests/net_test.rs @@ -13,6 +13,7 @@ use zenu::{ }, optimizer::{adam::Adam, sgd::SGD, Optimizer}, }; + use zenu_test::assert_val_eq; #[derive(Parameters)] @@ -167,23 +168,21 @@ fn adam_test() { assert_val_eq!( parameters["linear1.linear.weight"].clone(), linear1_weight, - 1e-4 + 2e-4 ); assert_val_eq!( parameters["linear1.linear.bias"].clone(), linear1_bias, - 1e-4 + 2e-4 ); assert_val_eq!( parameters["linear2.linear.weight"].clone(), linear2_weight, - 1e-4 + 2e-4 ); assert_val_eq!( parameters["linear2.linear.bias"].clone(), linear2_bias, - 1e-4 + 2e-4 ); - - panic!(); }