
Commit

Revert "remove pow polyfill" (#240)
* Revert "remove pow polyfill (#195)"

This reverts commit 1057c9c.

* bump version

* add test for loss and grad
L-M-Sherlock authored Oct 20, 2024
1 parent 81af1fa commit 5f43bc5
Showing 4 changed files with 77 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "1.3.3"
version = "1.3.4"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
18 changes: 15 additions & 3 deletions src/model.rs
@@ -26,6 +26,18 @@ impl<B: Backend, const N: usize> Get<B, N> for Tensor<B, N> {
}
}

trait Pow<B: Backend, const N: usize> {
    // https://github.com/burn-rs/burn/issues/590: once that issue is resolved,
    // remove this trait and the impl below.
    fn pow(&self, other: Tensor<B, N>) -> Tensor<B, N>;
}

impl<B: Backend, const N: usize> Pow<B, N> for Tensor<B, N> {
    fn pow(&self, other: Self) -> Self {
        // a ^ b => exp(ln(a^b)) => exp(b * ln(a))
        (self.clone().log() * other).exp()
    }
}

impl<B: Backend> Model<B> {
#[allow(clippy::new_without_default)]
pub fn new(config: ModelConfig) -> Self {
@@ -65,7 +77,7 @@ impl<B: Backend> Model<B> {
last_s.clone()
* (self.w.get(8).exp()
* (-last_d + 11)
-                * (last_s.powf_scalar(self.w.get(9).neg().into_scalar()))
+                * (last_s.pow(-self.w.get(9)))
* (((-r + 1) * self.w.get(10)).exp() - 1)
* hard_penalty
* easy_bonus
Expand All @@ -79,8 +91,8 @@ impl<B: Backend> Model<B> {
r: Tensor<B, 1>,
) -> Tensor<B, 1> {
let new_s = self.w.get(11)
-            * last_d.powf_scalar(self.w.get(12).neg().into_scalar())
-            * ((last_s.clone() + 1).powf_scalar(self.w.get(13).into_scalar()) - 1)
+            * last_d.pow(-self.w.get(12))
+            * ((last_s.clone() + 1).pow(self.w.get(13)) - 1)
* ((-r + 1) * self.w.get(14)).exp();
new_s
.clone()
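For context, the restored polyfill rewrites a tensor power a^b as exp(b · ln(a)), which is equivalent for positive bases. A minimal standalone sketch of the same identity on plain f32 values (an illustration only, not code from this commit):

// Sketch of the identity the polyfill relies on: a^b == exp(b * ln(a)) for a > 0.
// Plain f32 is used here; the trait above applies the same rewrite to burn tensors.
fn pow_via_exp_ln(a: f32, b: f32) -> f32 {
    (a.ln() * b).exp()
}

fn main() {
    let (a, b) = (2.5_f32, -0.7_f32);
    let direct = a.powf(b);
    let polyfill = pow_via_exp_ln(a, b);
    // The two agree to within floating-point rounding as long as the base is positive.
    assert!((direct - polyfill).abs() < 1e-5);
    println!("powf = {direct}, exp(b ln a) = {polyfill}");
}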
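The two truncated hunks above appear to implement the FSRS stability updates after a successful and a failed review. Written out (a reconstruction consistent with the code shown and the published FSRS formulation, not text taken from the diff):

S'_{\text{recall}} = S \cdot \left( e^{w_8}\,(11 - D)\,S^{-w_9}\,\left( e^{w_{10}(1-R)} - 1 \right) \cdot \text{hard\_penalty} \cdot \text{easy\_bonus} + 1 \right)

S'_{\text{forget}} = w_{11}\, D^{-w_{12}} \left( (S + 1)^{w_{13}} - 1 \right) e^{w_{14}(1-R)}

The practical difference between powf_scalar and the pow polyfill is that powf_scalar takes its exponent through into_scalar(), which detaches w_9, w_12, and w_13 from the autodiff graph, so those weights would receive no gradient; the polyfill keeps them in the graph, which is plausibly why this commit also adds a gradient test.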
60 changes: 60 additions & 0 deletions src/training.rs
@@ -441,6 +441,8 @@ mod tests {
use super::*;
use crate::convertor_tests::anki21_sample_file_converted_to_fsrs;
use crate::convertor_tests::data_from_csv;
use crate::dataset::FSRSBatch;
use burn::backend::NdArray;
use log::LevelFilter;

#[test]
@@ -450,6 +452,64 @@
assert_eq!(average_recall, 0.9435269);
}

#[test]
fn test_loss_and_grad() {
use burn::backend::ndarray::NdArrayDevice;
use burn::tensor::Data;

let config = ModelConfig::default();
let device = NdArrayDevice::Cpu;
let model: Model<Autodiff<NdArray<f32>>> = config.init();

let item = FSRSBatch {
t_historys: Tensor::from_floats(
Data::from([
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
[0.0, 1.0, 1.0, 3.0],
[1.0, 3.0, 3.0, 5.0],
[3.0, 6.0, 6.0, 12.0],
]),
&device,
),
r_historys: Tensor::from_floats(
Data::from([
[1.0, 2.0, 3.0, 4.0],
[3.0, 4.0, 2.0, 4.0],
[1.0, 4.0, 4.0, 3.0],
[4.0, 3.0, 3.0, 3.0],
[3.0, 1.0, 3.0, 3.0],
[2.0, 3.0, 3.0, 4.0],
]),
&device,
),
delta_ts: Tensor::from_floats(Data::from([4.0, 11.0, 12.0, 23.0]), &device),
labels: Tensor::from_ints(Data::from([1, 1, 1, 0]), &device),
};

let loss = model.forward_classification(
item.t_historys,
item.r_historys,
item.delta_ts,
item.labels,
Reduction::Sum,
);

assert_eq!(loss.clone().into_data().convert::<f32>().value[0], 4.380769);
let gradients = loss.backward();

let w_grad = model.w.grad(&gradients).unwrap();
dbg!(&w_grad);

Data::from([
-0.044447, -0.004000, -0.002020, 0.009756, -0.036012, 1.126084, 0.101431, -0.888184,
0.540923, -2.830812, 0.492003, -0.008362, 0.024086, -0.077360, -0.000585, -0.135484,
0.203740, 0.208560, 0.037535,
])
.assert_approx_eq(&w_grad.clone().into_data(), 5);
}

#[test]
fn training() {
if std::env::var("SKIP_TRAINING").is_ok() {
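A note on the new test: it can be run in isolation with cargo test test_loss_and_grad. The final assert_approx_eq(&w_grad.clone().into_data(), 5) call compares the nineteen weight gradients against the expected values at a precision argument of 5 (on the order of five decimal places), so small cross-platform floating-point differences are tolerated, while a missing gradient (for example a zero where -2.830812 is expected for w_9) would fail the assertion.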
