diff --git a/src/pre_training.rs b/src/pre_training.rs index eae87e5..e6f96b0 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -184,8 +184,8 @@ pub(crate) fn smooth_and_fill( } } - let w1 = 3.0 / 5.0; - let w2 = 3.0 / 5.0; + let w1 = 0.41; + let w2 = 0.54; let mut init_s0 = vec![]; @@ -357,7 +357,7 @@ mod tests { let items = [pretrainset.clone(), trainset].concat(); let average_recall = calculate_average_recall(&items); Data::from(pretrain(pretrainset, average_recall).unwrap().0) - .assert_approx_eq(&Data::from([0.908_688, 1.678_973, 4.216_837, 9.615_904]), 6) + .assert_approx_eq(&Data::from([0.908_688, 2.247_462, 4.216_837, 9.615_904]), 6) } #[test] @@ -365,7 +365,7 @@ mod tests { let mut rating_stability = HashMap::from([(1, 0.4), (3, 2.3), (4, 10.9)]); let rating_count = HashMap::from([(1, 1), (2, 1), (3, 1), (4, 1)]); let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap(); - assert_eq!(actual, [0.4, 0.8052433, 2.3, 10.9,]); + assert_eq!(actual, [0.4, 1.1227008, 2.3, 10.9,]); let mut rating_stability = HashMap::from([(2, 0.35)]); let rating_count = HashMap::from([(2, 1)]);