12-gradient-accum.rs (from coreylowman/dfdx)
use dfdx::{
    nn::ZeroGrads,
    prelude::*,
    tensor::{AutoDevice, Gradients},
};

fn main() {
    let dev = AutoDevice::default();
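
    // a model can be described as a tuple of layers; build_module
    // allocates it on the device and initializes its parameters.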
    type Model = (Linear<2, 5>, ReLU, Linear<5, 10>, Tanh, Linear<10, 20>);
    let model = dev.build_module::<Model, f32>();
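
    // a batch of 10 inputs with 2 features each, sampled from a
    // standard normal distribution.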
    let x: Tensor<Rank2<10, 2>, f32, _> = dev.sample_normal();

    // first we call .alloc_grads, which both pre-allocates gradients
    // and also marks non-parameter gradients as temporary.
    // this allows .backward() to drop temporary gradients.
    let mut grads: Gradients<f32, _> = model.alloc_grads();

    grads = model.forward(x.trace(grads)).mean().backward();

    // backward will return the same gradients object that we passed
    // into trace()
    grads = model.forward(x.trace(grads)).mean().backward();
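
    // since we didn't zero `grads` in between, backward() summed the new
    // gradients into the existing ones: this is gradient accumulation.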

    // you can do this as many times as you want!
    grads = model.forward(x.trace(grads)).mean().backward();
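
    // at this point `grads` holds the sum of the gradients from all three
    // backward passes; a real training loop would typically hand them to
    // an optimizer's update() before zeroing.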

    // finally, we can use ZeroGrads to zero out the accumulated gradients
    model.zero_grads(&mut grads);
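    // model.0 is the first Linear<2, 5>, so its weight gradient has
    // shape [5, 2] and should now be all zeros.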
    assert_eq!(grads.get(&model.0.weight).array(), [[0.0; 2]; 5]);
}