Skip to content

Commit

Permalink
Merge pull request #235 from MeetThePatel/interpolation
Browse files Browse the repository at this point in the history
Fixed the exponential interpolator
  • Loading branch information
avhz authored Jun 28, 2024
2 parents 0327174 + ab04773 commit 358d97c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 67 deletions.
142 changes: 77 additions & 65 deletions src/math/interpolation/exponential_interpolator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

//! Module containing functionality for interpolation.

use super::Interpolator;
use crate::math::interpolation::{InterpolationError, InterpolationIndex, InterpolationValue};
use num::Float;

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// STRUCTS & ENUMS
Expand All @@ -18,7 +20,7 @@ use crate::math::interpolation::{InterpolationError, InterpolationIndex, Interpo
/// Exponential Interpolator.
pub struct ExponentialInterpolator<IndexType, ValueType>
where
IndexType: InterpolationIndex,
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
ValueType: InterpolationValue,
{
/// X-axis values for the interpolator.
Expand Down Expand Up @@ -69,88 +71,98 @@ where
}
}

// impl<IndexType, ValueType> Interpolator<IndexType, ValueType>
// for ExponentialInterpolator<IndexType, ValueType>
// where
// IndexType: InterpolationIndex<DeltaDiv = ValueType>, //+ std::ops::Div<Output = ValueType>,
// ValueType: InterpolationValue + num_traits::Float,
// {
// fn fit(&mut self) -> Result<(), InterpolationError> {
// self.fitted = true;
// Ok(())
// }

// fn range(&self) -> (IndexType, IndexType) {
// (*self.xs.first().unwrap(), *self.xs.last().unwrap())
// }
impl<IndexType, ValueType> Interpolator<IndexType, ValueType>
for ExponentialInterpolator<IndexType, ValueType>
where
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
ValueType: InterpolationValue + Float,
{
fn fit(&mut self) -> Result<(), InterpolationError> {
self.fitted = true;
Ok(())
}

// fn add_point(&mut self, point: (IndexType, ValueType)) {
// let idx = self.xs.partition_point(|&x| x < point.0);
fn range(&self) -> (IndexType, IndexType) {
(*self.xs.first().unwrap(), *self.xs.last().unwrap())
}

// self.xs.insert(idx, point.0);
// self.ys.insert(idx, point.1);
// }
fn add_point(&mut self, point: (IndexType, ValueType)) {
let idx = self.xs.partition_point(|&x| x < point.0);

// fn interpolate(&self, point: IndexType) -> Result<ValueType, InterpolationError> {
// let range = self.range();
self.xs.insert(idx, point.0);
self.ys.insert(idx, point.1);
}

// if point.partial_cmp(&range.0).unwrap() == std::cmp::Ordering::Less
// || point.partial_cmp(&range.1).unwrap() == std::cmp::Ordering::Greater
// {
// return Err(InterpolationError::OutsideOfRange);
// }
fn interpolate(&self, point: IndexType) -> Result<ValueType, InterpolationError> {
let range = self.range();

// if let Ok(idx) = self
// .xs
// .binary_search_by(|p| p.partial_cmp(&point).expect("Cannot compare values."))
// {
// return Ok(self.ys[idx]);
// }
if point.partial_cmp(&range.0).unwrap() == std::cmp::Ordering::Less
|| point.partial_cmp(&range.1).unwrap() == std::cmp::Ordering::Greater
{
return Err(InterpolationError::OutsideOfRange);
}

// let idx_r = self.xs.partition_point(|&x| x < point);
// let idx_l = idx_r - 1;
if let Ok(idx) = self
.xs
.binary_search_by(|p| p.partial_cmp(&point).expect("Cannot compare values."))
{
return Ok(self.ys[idx]);
}

// let lambda = (self.xs[idx_r] - point) / (self.xs[idx_r] - self.xs[idx_l]);
let idx_r = self.xs.partition_point(|&x| x < point);

// let exponent_1 = lambda * (point / self.xs[idx_l]);
// let exponent_2 = point / self.xs[idx_r] - lambda * (point / self.xs[idx_r]);
let x_l = self.xs[idx_r - 1];
let y_l = self.ys[idx_r - 1];

// let term_1 = self.ys[idx_l].powf(exponent_1);
// let term_2 = self.ys[idx_r].powf(exponent_2);
let x_r = self.xs[idx_r];
let y_r = self.ys[idx_r];

// let result = term_1 * term_2;
let result = ((y_r.ln() - y_l.ln()) * ((point - x_l) / (x_r - x_l)) + y_l.ln()).exp();

// Ok(result)
// }
// }
Ok(result)
}
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Unit tests
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// #[cfg(test)]
// mod tests_exponential_interpolation {
// use super::*;
// use crate::{assert_approx_equal, RUSTQUANT_EPSILON};
// use time::macros::date;
#[cfg(test)]
mod tests_exponential_interpolation {
use super::*;
use crate::{assert_approx_equal, RUSTQUANT_EPSILON};
use time::macros::date;

#[test]
fn test_exponential_interpolation_numbers() {
let xs = vec![1.0, 2.0, 3.0, 5.0];
let ys = vec![5.0, 25.0, 125.0, 3125.0];

let interpolator = ExponentialInterpolator::new(xs, ys).unwrap();
assert_approx_equal!(
625.0,
interpolator.interpolate(4.0).unwrap(),
RUSTQUANT_EPSILON
);
}

// // #[test]
// // fn test_exponential_interpolation_dates() {
// // let d_1m = date!(1990 - 06 - 16);
// // let d_2m = date!(1990 - 07 - 17);
#[test]
fn test_exponential_interpolation_dates() {
let d_1m = date!(1990 - 06 - 16);
let d_2m = date!(1990 - 07 - 17);

// // let r_1m = 0.9870;
// // let r_2m = 0.9753;
let r_1m = 0.9870;
let r_2m = 0.9753;

// // let dates = vec![d_1m, d_2m];
// // let rates = vec![r_1m, r_2m];
let dates = vec![d_1m, d_2m];
let rates = vec![r_1m, r_2m];

// // let mut interpolator = ExponentialInterpolator::new(dates, rates).unwrap();
let interpolator = ExponentialInterpolator::new(dates, rates).unwrap();

// // assert_approx_equal!(
// // 0.9854,
// // interpolator.interpolate(date!(1990 - 06 - 20)).unwrap(),
// // RUSTQUANT_EPSILON
// // );
// // }
// }
assert_approx_equal!(
0.9854824711068088,
interpolator.interpolate(date!(1990 - 06 - 20)).unwrap(),
RUSTQUANT_EPSILON
);
}
}
4 changes: 2 additions & 2 deletions src/math/interpolation/linear_interpolator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::math::interpolation::{
/// Linear Interpolator.
pub struct LinearInterpolator<IndexType, ValueType>
where
IndexType: InterpolationIndex,
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
ValueType: InterpolationValue,
{
/// X-axis values for the interpolator.
Expand All @@ -39,7 +39,7 @@ where

impl<IndexType, ValueType> LinearInterpolator<IndexType, ValueType>
where
IndexType: InterpolationIndex,
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
ValueType: InterpolationValue,
{
/// Create a new LinearInterpolator.
Expand Down

0 comments on commit 358d97c

Please sign in to comment.