Skip to content

Commit

Permalink
add autograd show drop tensor name
Browse files Browse the repository at this point in the history
  • Loading branch information
bokutotu committed Nov 2, 2024
1 parent 50dc949 commit f6d0083
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
1 change: 1 addition & 0 deletions zenu-autograd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ rand = "0.8.5"
rand_distr = "0.4.3"
lazy_static = "1.4.0"
serde = { version = "1.0.197", features = ["derive"] }
once_cell = "1.19.0"

[dev-dependencies]
criterion = "0.5.1"
Expand Down
4 changes: 4 additions & 0 deletions zenu-autograd/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ pub fn concat<T: Num, D: Device>(vars: &[Variable<T, D>]) -> Variable<T, D> {
concat.forward();

output.set_creator(Rc::new(RefCell::new(Box::new(concat))));
output.set_name("concat_output");
output
}

Expand All @@ -118,6 +119,9 @@ fn concat_grad<T: Num, D: Device>(input: Variable<T, D>) -> Vec<Variable<T, D>>
Box::new(concat_grad) as Box<dyn Function<T, D>>
));
outputs.iter().for_each(|v| v.set_creator(layer.clone()));
for (idx, v) in outputs.iter().enumerate() {
v.set_name(&format!("concat_grad_output_{idx}"));
}
outputs
}

Expand Down
25 changes: 25 additions & 0 deletions zenu-autograd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ use zenu_matrix::{
num::Num,
};

pub(crate) struct ZenuAutogradState {
pub(crate) is_drop_name_show: bool,
}

impl Default for ZenuAutogradState {
fn default() -> Self {
let is_drop_name_show =
std::env::var("ZENU_DROP_NAME_SHOW").unwrap_or("1".to_string()) == "1";
ZenuAutogradState { is_drop_name_show }
}
}

pub(crate) static ZENU_AUTOGRAD_STATE: once_cell::sync::Lazy<ZenuAutogradState> =
once_cell::sync::Lazy::new(ZenuAutogradState::default);

pub trait Function<T: Num, D: Device> {
fn forward(&self);
fn backward(&self);
Expand Down Expand Up @@ -116,6 +131,16 @@ pub struct VariableInner<T: Num, D: Device> {
is_train: bool,
}

impl<T: Num, D: Device> Drop for VariableInner<T, D> {
fn drop(&mut self) {
if ZENU_AUTOGRAD_STATE.is_drop_name_show {
if let Some(name) = self.name.clone() {
println!("Drop Variable: {name}");
}
}
}
}

impl<T, D> Serialize for VariableInner<T, D>
where
T: Num,
Expand Down

0 comments on commit f6d0083

Please sign in to comment.