Commit 9d5d2f3

Fix/keep consistent with fsrs-optimizer (#246)
* set eps of adam to 1e-8

* set seed to 2023

* filter and sort train_set when benchmark

* locate the bug related to parameter_clipper

* fix parameter_clipper

* update to burn v0.13.2

* bump version

* make max_seq_len configurable & sort by length
L-M-Sherlock authored Oct 28, 2024 · 1 parent 1719019 · commit 9d5d2f3
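
The commit message's first two bullets align the Rust trainer with the Python fsrs-optimizer's defaults, and the last one concerns how the training set is ordered. A minimal sketch of what these amount to, assuming burn's stock AdamConfig and Backend::seed APIs; train_set here is a hypothetical stand-in for the crate's actual item type:

use burn::backend::NdArray;
use burn::optim::AdamConfig;
use burn::tensor::backend::Backend;

fn main() {
    // Match PyTorch's Adam default of eps = 1e-8 so both optimizers
    // take numerically comparable steps.
    let _optim = AdamConfig::new().with_epsilon(1e-8);

    // Fix the RNG seed so benchmark runs are reproducible across ports.
    NdArray::<f32>::seed(2023);

    // "sort by length": ordering sequences by review count lets batches
    // be built from similarly sized neighbours, reducing padding.
    let mut train_set: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![1], vec![1, 2]];
    train_set.sort_by_key(|seq| seq.len());
    assert_eq!(train_set[0].len(), 1);
}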
Showing 4 changed files with 170 additions and 52 deletions.
62 changes: 31 additions & 31 deletions Cargo.lock

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "1.4.0"
+version = "1.4.1"
 authors = ["Open Spaced Repetition"]
 categories = ["algorithms", "science"]
 edition = "2021"
@@ -15,15 +15,15 @@ description = "FSRS for Rust, including Optimizer and Scheduler"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

 [dependencies.burn]
-version = "0.13.1"
+version = "0.13.2"
 # git = "https://github.com/tracel-ai/burn.git"
 # rev = "6ae3926006872a204869e84ffc303417c54b6b7f"
 # path = "../burn/burn"
 default-features = false
 features = ["std", "train", "ndarray"]

 [dev-dependencies.burn]
-version = "0.13.1"
+version = "0.13.2"
 # git = "https://github.com/tracel-ai/burn.git"
 # rev = "6ae3926006872a204869e84ffc303417c54b6b7f"
 # path = "../burn/burn"
24 changes: 17 additions & 7 deletions src/parameter_clipper.rs
@@ -2,13 +2,23 @@ use crate::{
     inference::{Parameters, S_MIN},
     pre_training::INIT_S_MAX,
 };
-use burn::tensor::{backend::Backend, Data, Tensor};
+use burn::{
+    module::Param,
+    tensor::{backend::Backend, Data, Tensor},
+};

-pub(crate) fn parameter_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
-    let val = clip_parameters(&parameters.to_data().convert().value);
-    Tensor::from_data(
-        Data::new(val, parameters.shape()).convert(),
-        &B::Device::default(),
+pub(crate) fn parameter_clipper<B: Backend>(
+    parameters: Param<Tensor<B, 1>>,
+) -> Param<Tensor<B, 1>> {
+    let (id, val) = parameters.consume();
+    let clipped = clip_parameters(&val.to_data().convert().value);
+    Param::initialized(
+        id,
+        Tensor::from_data(
+            Data::new(clipped, val.shape()).convert(),
+            &B::Device::default(),
+        )
+        .require_grad(),
     )
 }

@@ -58,7 +68,7 @@ mod tests {
             &device,
         );

-        let param: Tensor<1> = parameter_clipper(tensor);
+        let param = parameter_clipper(Param::from_tensor(tensor));
         let values = &param.to_data().value;

         assert_eq!(
[remainder of the hunk collapsed in the original view]
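
The two parameter_clipper hunks above are the substance of the "locate"/"fix" bullets. The old function took and returned a bare Tensor, so the caller had to wrap the result in a new Param, and that wrap minted a fresh parameter id — the key burn uses to pair each weight with its optimizer state and gradient requirement. Consuming the incoming Param and re-initializing it under the same id keeps that pairing intact. A self-contained sketch of the pattern on the ndarray backend; clamp_param and its range are illustrative stand-ins, not the crate's clip_parameters:

use burn::backend::NdArray;
use burn::module::Param;
use burn::tensor::Tensor;

type B = NdArray<f32>;

fn clamp_param(p: Param<Tensor<B, 1>>) -> Param<Tensor<B, 1>> {
    // Take the tensor out while keeping its ParamId, then rebuild the
    // Param under the same id so optimizer bookkeeping still applies.
    let (id, tensor) = p.consume();
    Param::initialized(id, tensor.clamp(0.001, 100.0).require_grad())
}

fn main() {
    let device = Default::default();
    let raw = Tensor::<B, 1>::from_floats([-1.0, 0.5, 200.0], &device);
    let clipped = clamp_param(Param::from_tensor(raw));
    // Out-of-range values are pulled back into [0.001, 100.0].
    println!("{:?}", clipped.val().to_data());
}

The updated test mirrors the new signature: the fixture tensor is wrapped with Param::from_tensor before being handed to the clipper.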