Skip to content

Commit

Permalink
Remove Rayon, add error handling for PLS regression, and enhance docu…
Browse files Browse the repository at this point in the history
…mentation

- **Removed Rayon dependency**: Eliminated the use of Rayon for parallel iterations in the algorithm code to simplify dependencies and execution flow.
- **Added error handling for PLS regression**: Wrapped the `PlsRegression::fit` method with `panic::catch_unwind` to catch panics during model fitting and convert them into `KryptoError::PlsError`.
- **Enhanced documentation**: Added Rust doc comments (`///` and `/** ... */`) to public functions and structs for better code readability and generated documentation.
- **Refactored `IntervalData`**: Moved the computation of normalized predictors into the `IntervalData` struct to improve efficiency by precomputing these values.
- **Adjusted test parameters**: Updated test configurations in `tests/algorithm.rs` for more comprehensive testing.
- **Extended error handling**: Introduced a new error variant `KryptoError::PlsError` to handle PLS regression-specific errors.
  • Loading branch information
noahbclarkson committed Nov 22, 2024
1 parent ce2ea64 commit 63db524
Show file tree
Hide file tree
Showing 15 changed files with 278 additions and 72 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ linfa = "0.7"
linfa-pls = "0.7"
ndarray = "0.15"
derive_builder = "0.20"
rayon = "1.10"

[dev-dependencies]
tempfile = "3.14"
61 changes: 45 additions & 16 deletions src/algorithm/algo.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::fmt;

use linfa_pls::PlsRegression;
use rayon::prelude::*;
use tracing::{debug, info, instrument};

use crate::{
Expand All @@ -22,6 +21,17 @@ pub struct Algorithm {
}

impl Algorithm {
/**
Load the algorithm with the given settings and interval dataset.
## Arguments
* `interval_dataset` - The interval dataset to use for training and testing the algorithm.
* `settings` - The settings to use for the algorithm.
* `config` - The configuration to use for the algorithm.
## Returns
The loaded algorithm.
*/
#[instrument(skip(interval_dataset, config))]
pub fn load(
interval_dataset: &IntervalData,
Expand All @@ -42,6 +52,17 @@ impl Algorithm {
})
}

/**
Run a backtest on the given interval dataset with the given settings and configuration.
## Arguments
* `interval_dataset` - The interval dataset to use for training and testing the algorithm.
* `settings` - The settings to use for the algorithm.
* `config` - The configuration to use for the algorithm.
## Returns
The result of the backtest.
*/
fn backtest(
interval_dataset: &IntervalData,
settings: &AlgorithmSettings,
Expand All @@ -55,7 +76,6 @@ impl Algorithm {
let test_data_size = total_size / count;

let test_results: Vec<TestData> = (0..count)
.into_par_iter()
.map(|i| -> Result<TestData, KryptoError> {
let start = i * test_data_size;
let end = match i == count - 1 {
Expand Down Expand Up @@ -85,25 +105,23 @@ impl Algorithm {
})
.collect::<Result<Vec<_>, KryptoError>>()?;

let median_return = median(
&test_results
.iter()
.map(|d| d.monthly_return)
.filter(|&v| v.is_finite())
.collect::<Vec<_>>(),
);
let median_accuracy = median(
&test_results
.iter()
.map(|d| d.accuracy)
.filter(|&v| v.is_finite())
.collect::<Vec<_>>(),
);
let median_return = median(&TestData::get_monthly_returns(&test_results));
let median_accuracy = median(&TestData::get_accuracies(&test_results));
let result = AlgorithmResult::new(median_return, median_accuracy);
info!("Backtest result: {}", result);
Ok(result)
}

/**
Run a backtest on all seen data.
## Arguments
* `interval_dataset` - The interval dataset to use for training and testing the algorithm.
* `config` - The configuration to use for the algorithm.
## Returns
The result of the backtest.
*/
#[instrument(skip(interval_dataset, config, self))]
pub fn backtest_on_all_seen_data(
&self,
Expand Down Expand Up @@ -171,6 +189,17 @@ impl AlgorithmSettings {
}
}

/**
Generate all possible algorithm settings for the given symbols, max_n, and max_depth.
## Arguments
* `symbols` - The symbols to generate settings for.
* `max_n` - The maximum number of components to use.
* `max_depth` - The maximum depth to use.
## Returns
A vector of all possible algorithm settings.
*/
pub fn all(symbols: Vec<String>, max_n: usize, max_depth: usize) -> Vec<Self> {
symbols
.iter()
Expand Down
7 changes: 5 additions & 2 deletions src/algorithm/pls.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::panic;

use linfa::traits::{Fit, Predict as _};
use linfa_pls::PlsRegression;
use ndarray::Array2;
Expand All @@ -14,8 +16,9 @@ pub fn get_pls(
let predictors: Array2<f64> = Array2::from_shape_vec(shape, flattened_predictors)?;
let target: Array2<f64> = Array2::from_shape_vec((target.len(), 1), target)?;
let ds = linfa::dataset::Dataset::new(predictors, target);
let pls = PlsRegression::params(n).fit(&ds)?;
Ok(pls)
let pls = panic::catch_unwind(|| PlsRegression::params(n).fit(&ds))
.map_err(|e| KryptoError::PlsError(format!("{:?}", e)))?;
Ok(pls?)
}

pub fn predict(pls: &PlsRegression<f64>, features: &[Vec<f64>]) -> Result<Vec<f64>, KryptoError> {
Expand Down
28 changes: 27 additions & 1 deletion src/algorithm/test_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ pub struct TestData {
}

impl TestData {
/**
Create a new test data instance from the given predictions and candles.
This will simulate trading based on the predictions and candles.
## Arguments
* `predictions` - The predictions to use for trading.
* `candles` - The candles to use for trading.
* `config` - The configuration to use for trading.
## Returns
A Result containing the test data if successful, or a KryptoError if an error occurred.
*/
pub fn new(
predictions: Vec<f64>,
candles: &[Candlestick],
Expand Down Expand Up @@ -71,7 +83,7 @@ impl TestData {
0 => 0.5,
_ => inner.correct as f64 / total_trades as f64,
};
let monthly_return = if months > 0.0 && inner.cash.is_finite() && inner.cash > 0.0 && inner.cash_history.len() > 1 {
let monthly_return = if months > 0.0 && inner.cash.is_finite() && inner.cash > 0.0 {
(inner.cash / 1000.0).powf(1.0 / months) - 1.0
} else {
0.0
Expand All @@ -83,6 +95,20 @@ impl TestData {
monthly_return,
})
}

pub fn get_accuracies(data: &[Self]) -> Vec<f64> {
data.iter()
.map(|d| d.accuracy)
.filter(|&v| v.is_finite())
.collect()
}

pub fn get_monthly_returns(data: &[Self]) -> Vec<f64> {
data.iter()
.map(|d| d.monthly_return)
.filter(|&v| v.is_finite())
.collect()
}
}

impl fmt::Display for TestData {
Expand Down
8 changes: 2 additions & 6 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ impl KryptoConfig {

let file = File::open(&path)?;
let reader = BufReader::new(file);
let config: Self =
from_reader(reader)?;
let config: Self = from_reader(reader)?;
let account: Account = config.get_binance();
if config.api_key.is_some() || config.api_secret.is_some() {
let account_info = account.get_account().map_err(|e| {
Expand Down Expand Up @@ -420,10 +419,7 @@ cross-validations: 25
let config_result = KryptoConfig::read_config(Some(temp_file.path()));

// Verify that deserialization fails because intervals is a required field
assert!(matches!(
config_result,
Err(KryptoError::SerdeYamlError(_))
));
assert!(matches!(config_result, Err(KryptoError::SerdeYamlError(_))));
}

#[test]
Expand Down
Loading

0 comments on commit 63db524

Please sign in to comment.