Skip to content

Commit

Permalink
add adamw tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bokutotu committed Nov 11, 2024
1 parent 86a6d01 commit 73493b0
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 12 deletions.
1 change: 0 additions & 1 deletion zenu-optimizer/src/adamw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for AdamW<T, D>
let update = m_hat / denom;

if weight_keys.contains(&k) {
println!("Weight decay");
data.get_as_mut().sub_assign(
&(data.get_as_ref() * self.learning_rate * self.weight_decay).to_ref(),
);
Expand Down
11 changes: 0 additions & 11 deletions zenu-optimizer/tests/net_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,6 @@ fn adam_w_test() {
let optimizer = AdamW::new(0.01, 0.9, 0.999, 1e-8, 0.01, &net);
let _ = test_funcion_inner(&net, &optimizer);
let parameters = test_funcion_inner(&net, &optimizer);
// linear1.weight tensor([[0.0801, 0.1800],
// [0.2800, 0.3800],
// [0.4800, 0.5800],
// [0.0501, 0.0601]])
// linear1.bias tensor([0.0801, 0.1800, 0.2800, 0.3800])
// linear2.weight tensor([[0.0801, 0.1800, 0.2800, 0.3800],
// [0.4800, 0.5800, 0.0501, 0.0601],
// [0.0702, 0.0801, 0.0901, 0.1001],
// [0.1101, 0.1201, 0.1301, 0.1401]])
// linear2.bias tensor([0.0800, 0.1800, 0.2801, 0.3800])

let linear1_weight = vec![
0.0801, 0.1800, 0.2800, 0.3800, 0.4800, 0.5800, 0.0501, 0.0601,
];
Expand Down

0 comments on commit 73493b0

Please sign in to comment.