Skip to content

Commit

Permalink
Fix test data range and update training data splitting; enhance loggi…
Browse files Browse the repository at this point in the history
…ng in IntervalData

 In `Algorithm`:
  - Changed test results loop from `(0..count)` to `(1..count)` to correct data indexing.
  - Modified training data generation to use only features and labels up to the start index instead of concatenating before and after slices.

- In `IntervalData`:
  - Consolidated multiple debug statements into a single line for clearer and more concise logging.
  • Loading branch information
noahbclarkson committed Nov 27, 2024
1 parent 63db524 commit 93d103a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
7 changes: 3 additions & 4 deletions src/algorithm/algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,19 @@ impl Algorithm {
let total_size = ds.len()?;
let test_data_size = total_size / count;

let test_results: Vec<TestData> = (0..count)
let test_results: Vec<TestData> = (1..count)
.map(|i| -> Result<TestData, KryptoError> {
let start = i * test_data_size;
let end = match i == count - 1 {
true => total_size,
false => (i + 1) * test_data_size,
};
let features = ds.get_features();
let labels = ds.get_labels();
let candles = ds.get_candles();
let test_features = &features[start..end];
let test_candles = &candles[start..end];
let train_features = [&features[..start], &features[end..]].concat();
let train_labels = [&labels[..start], &labels[end..]].concat();
let train_features = features[..start].to_vec();
let train_labels = ds.get_labels()[..start].to_vec();

let pls = get_pls(train_features, train_labels, settings.n)?;
let predictions = predict(&pls, test_features)?;
Expand Down
4 changes: 1 addition & 3 deletions src/data/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,7 @@ impl IntervalData {
.cloned()
.collect();

debug!("Features shape: {}x{}", features.len(), features[0].len());
debug!("Labels count: {}", labels.len());
debug!("Candles count: {}", candles.len());
debug!("Features: {}x{} | Labels: {} | Candles: {}", features.len(), features[0].len(), labels.len(), candles.len());

SymbolDataset::new(features, labels, candles)
}
Expand Down

0 comments on commit 93d103a

Please sign in to comment.