From 73493b088a4c7484fd267cbca0e8e7bcf274a6f2 Mon Sep 17 00:00:00 2001
From: bokutotu
Date: Tue, 12 Nov 2024 01:50:14 +0900
Subject: [PATCH] add adamw tests

---
 zenu-optimizer/src/adamw.rs      |  1 -
 zenu-optimizer/tests/net_test.rs | 11 -----------
 2 files changed, 12 deletions(-)

diff --git a/zenu-optimizer/src/adamw.rs b/zenu-optimizer/src/adamw.rs
index 9489820c..c9739696 100644
--- a/zenu-optimizer/src/adamw.rs
+++ b/zenu-optimizer/src/adamw.rs
@@ -59,7 +59,6 @@ impl> Optimizer for AdamW
         let update = m_hat / denom;
 
         if weight_keys.contains(&k) {
-            println!("Weight decay");
             data.get_as_mut().sub_assign(
                 &(data.get_as_ref() * self.learning_rate * self.weight_decay).to_ref(),
             );
diff --git a/zenu-optimizer/tests/net_test.rs b/zenu-optimizer/tests/net_test.rs
index 13ac5693..da43c50a 100644
--- a/zenu-optimizer/tests/net_test.rs
+++ b/zenu-optimizer/tests/net_test.rs
@@ -194,17 +194,6 @@ fn adam_w_test() {
     let optimizer = AdamW::new(0.01, 0.9, 0.999, 1e-8, 0.01, &net);
     let _ = test_funcion_inner(&net, &optimizer);
     let parameters = test_funcion_inner(&net, &optimizer);
-    // linear1.weight tensor([[0.0801, 0.1800],
-    //                        [0.2800, 0.3800],
-    //                        [0.4800, 0.5800],
-    //                        [0.0501, 0.0601]])
-    // linear1.bias tensor([0.0801, 0.1800, 0.2800, 0.3800])
-    // linear2.weight tensor([[0.0801, 0.1800, 0.2800, 0.3800],
-    //                        [0.4800, 0.5800, 0.0501, 0.0601],
-    //                        [0.0702, 0.0801, 0.0901, 0.1001],
-    //                        [0.1101, 0.1201, 0.1301, 0.1401]])
-    // linear2.bias tensor([0.0800, 0.1800, 0.2801, 0.3800])
-
     let linear1_weight = vec![
         0.0801, 0.1800, 0.2800, 0.3800, 0.4800, 0.5800, 0.0501, 0.0601,
     ];
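
Note (not part of the patch): the adamw.rs hunk only drops a debug println! from the
decoupled weight-decay branch of the AdamW update. For context, the sketch below shows
what that branch computes, using plain f32 slices in place of zenu's tensor types; the
function name adamw_step and its signature are illustrative assumptions, not the crate's
actual API.

// Minimal sketch of one AdamW parameter update, assuming m_hat and denom
// already hold the bias-corrected first moment and sqrt(v_hat) + eps.
fn adamw_step(
    param: &mut [f32],
    m_hat: &[f32],
    denom: &[f32],
    lr: f32,
    weight_decay: f32,
    apply_decay: bool, // true for entries in weight_keys in the diff
) {
    for i in 0..param.len() {
        if apply_decay {
            // Decoupled weight decay: shrink the parameter directly,
            // independently of the gradient-based update. This is what
            // distinguishes AdamW from Adam with L2 regularization.
            param[i] -= lr * weight_decay * param[i];
        }
        // Gradient-based Adam update, matching `let update = m_hat / denom;`
        // in the diff (its application is outside the shown hunk).
        param[i] -= lr * (m_hat[i] / denom[i]);
    }
}

Because the decay term is applied to the parameter itself rather than folded into the
gradient, the debug println! sat on the hot path of every decayed parameter, which is
why removing it is the only change needed in adamw.rs.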