diff --git a/Cargo.lock b/Cargo.lock index d5bab4a..1927833 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,15 @@ dependencies = [ "libc", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "argminmax" version = "0.6.2" @@ -127,15 +136,15 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "cc" -version = "1.1.18" +version = "1.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" dependencies = [ "jobserver", "libc", @@ -367,9 +376,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -442,9 +451,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.158" +version = "0.2.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" [[package]] name = "libm" @@ -470,19 +479,18 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lz4" -version = "1.26.0" +version = "1.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958b4caa893816eea05507c20cfe47574a43d9a697138a7872990bba8a0ece68" +checksum = "a231296ca742e418c43660cb68e082486ff2538e8db432bc818580f3965025ed" dependencies = [ - "libc", "lz4-sys", ] [[package]] name = "lz4-sys" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "109de74d5d2353660401699a4174a4ff23fcc649caf553df71933c7fb45ad868" +checksum = "fcb44a01837a858d47e5a630d2ccf304c8efcc4b83b8f9f75b7a9ee4fcc6e57d" dependencies = [ "cc", "libc", @@ -716,9 +724,9 @@ dependencies = [ [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "planus" @@ -1123,6 +1131,7 @@ dependencies = [ name = "polars-trading" version = "0.1.0" dependencies = [ + "approx", "num", "polars", "pyo3", @@ -1155,9 +1164,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce" [[package]] name = "ppv-lite86" @@ -1188,9 +1197,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" dependencies = [ "cfg-if", "indoc", @@ -1206,9 +1215,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" dependencies = [ "once_cell", "target-lexicon", @@ -1216,9 +1225,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" dependencies = [ "libc", "pyo3-build-config", @@ -1226,9 +1235,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1238,9 +1247,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" dependencies = [ "heck", "proc-macro2", @@ -1382,9 +1391,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.3" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +checksum = "355ae415ccd3a04315d3f8246e86d67689ea74d88d915576e1589a351062a13b" dependencies = [ "bitflags", ] @@ -1489,9 +1498,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "simdutf8" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" [[package]] name = "siphasher" @@ -1633,18 +1642,18 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", @@ -1653,9 +1662,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-reverse" @@ -1668,15 +1677,15 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unindent" diff --git a/Cargo.toml b/Cargo.toml index bc30845..9889a0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,6 @@ pyo3-polars = { version = "0.16.0", features = ["derive", "dtype-struct"] } serde = { version = "1", features = ["derive"] } polars = { version = "0.42.0", features = ["dtype-struct"] } num = "0.4.3" + +[dev-dependencies] +approx = "0.5.1" diff --git a/src/bars.rs b/src/bars.rs index 447366c..12e6a1e 100644 --- a/src/bars.rs +++ b/src/bars.rs @@ -42,8 +42,7 @@ where } #[derive(Deserialize)] -struct BarGroupKwargs -{ +struct BarGroupKwargs { bar_size: f64, } diff --git a/src/labels.rs b/src/labels.rs index 53760d4..fb3a9f7 100644 --- a/src/labels.rs +++ b/src/labels.rs @@ -1,64 +1,545 @@ -// #![allow(clippy::unused_unit)] -// use std::cmp::PartialOrd; - -// use polars::prelude::*; -// use pyo3_polars::derive::polars_expr; -// use serde::Deserialize; - -// fn apply_profit_taking_stop_loss( -// index: &ChunkedArray, -// prices: &Float64Chunked, -// profit_taking: &Float64Chunked, -// stop_loss: &Float64Chunked, -// ) -> (Option, Option) -// where -// T: PartialOrd + Clone, -// { -// let returns: Vec = prices -// .iter() -// .map(|x| x.unwrap() / prices.get(0).unwrap() - 1.0) -// .collect(); -// // Get the minimum index where profit take is greater than returns -// let profit_taking_index = returns -// .iter() -// .zip(profit_taking.iter()) -// .position(|(&ret, &pt)| ret >= pt); -// let stop_loss_index = returns -// .iter() -// .zip(stop_loss.iter()) -// .position(|(&ret, &sl)| ret <= sl); - -// match (profit_taking_index, stop_loss_index) { -// (Some(pt), Some(sl)) => { -// return ( -// Some(index.get(pt).unwrap().clone()), -// Some(index.get(sl).unwrap().clone()), -// ) -// }, -// (Some(pt), None) => return (Some(index.get(pt).unwrap().clone()), None), -// (None, Some(sl)) => return (None, Some(index.get(sl).unwrap().clone())), -// (None, None) => return (None, None), -// } -// } - -// fn barrier_touch_struct(input_fields: &[Field]) -> PolarsResult { -// let dtype = input_fields[0].data_type(); -// Ok(Field::new( -// input_fields[0].name(), -// DataType::Struct(vec![ -// Field::new("barrier_touch_start", dtype.clone()), -// Field::new("barrier_touch_profit_take", dtype.clone()), -// Field::new("barrier_touch_stop_loss", dtype.clone()), -// Field::new("barrier_touch_vertical_barrier", dtype.clone()), -// ]), -// )) -// } - -// #[polars_expr(output_type_func=barrier_touch_struct)] -// fn get_barrier_touches(inputs: &[Series]) -> PolarsResult { -// let targets = inputs[0].datetime()?; // Not sure what to do with this type yet. -// let prices = inputs[1].f64()?; -// let profit_taking = inputs[2].f64()?; -// let stop_loss = inputs[3].f64()?; -// let (pt, sl) = apply_profit_taking_stop_loss(targets, prices, profit_taking, stop_loss); -// } +#![allow(clippy::unused_unit)] +/// TODOS: +/// - [ ] Add bitmask +/// - [ ] Handle 0 size price paths +/// - [ ] Calculate barrier touch from index +use polars::prelude::*; +use pyo3_polars::derive::polars_expr; +use serde::Deserialize; + +/// Returns the start and end indices of a slice range within a vector of i64 values. +/// +/// # Arguments +/// +/// * `data` - A vector of i64 values to search within. +/// * `start` - The value to search for as the start of the range. +/// * `end` - The value to search for as the end of the range. +/// +/// # Returns +/// +/// * `Ok((usize, usize))` - A tuple containing the start and end indices if both are found. +/// * `Err(String)` - An error message if either the start or end value (or both) are not found in the data. +/// +/// # Examples +/// +/// ``` +/// let data = vec![1, 2, 3, 4, 5]; +/// assert_eq!(get_slice_range(data, 2, 4), Ok((1, 3))); +/// ``` +fn get_slice_range(data: &Vec, start: i64, end: i64) -> Result<(usize, usize), String> { + let start_idx = data.iter().position(|&r| r == start); + let end_idx = data.iter().position(|&r| r == end); + match (start_idx, end_idx) { + (Some(start_idx), Some(end_idx)) => Ok((start_idx, end_idx)), + (Some(_), None) => Err(format!("End index {} not found in index", end).into()), + (None, Some(_)) => Err(format!("Start index {} not found in index", start).into()), + (None, None) => Err(format!( + "Both start index {} and end index {} not found in index", + start, end + ) + .into()), + } +} + +/// Calculate the returns of a given price path +/// +/// I do this slightly differently than Lopez de Prado. In AFML pg. 46, he calculates +/// the returns by setting the first price to the price before the price path. This +/// seems a little off to me, since it means the first price in your price path does +/// not have a 0 return. This means when you use this label to train a model, you have +/// to be careful to not use the data from the date of the label. I prefer to set the +/// returns so the first return in the price path is 0. This way, you can use all the +/// data up to the close price of the date of the label. +/// +/// # Arguments +/// +/// * `prices` - A vector of prices to calculate the returns of. +/// +/// # Returns +/// +/// * `Vec` - A vector of returns for the given price path. +/// +/// # Examples +/// +/// ``` +/// let prices = vec![1.0, 2.0, 3.0]; +/// assert_eq!(calculate_price_path_return(prices), vec![Some(0.0), Some(1.0), Some(0.5)]); +/// ``` +fn calculate_price_path_return(prices: Vec) -> Vec { + let first_price = prices[0]; + prices.iter().map(|x| x / first_price - 1.0).collect() +} + +struct TripleBarrierLabel { + ret: f64, + label: i64, + barrier_touch: i64, +} + +/// Calculate the label for a given price path +fn get_label( + returns: &[f64], + profit_taking: Option, + stop_loss: Option, + zero_vertical_barrier: bool, +) -> TripleBarrierLabel { + let pt_touch_idx = match profit_taking { + Some(pt) => returns.iter().position(|&r| r >= pt), + None => None, + }; + let sl_touch_idx = match stop_loss { + Some(sl) => returns.iter().position(|&r| r <= sl), + None => None, + }; + match (pt_touch_idx, sl_touch_idx) { + (Some(pt_touch_idx), Some(sl_touch_idx)) => { + if pt_touch_idx < sl_touch_idx { + TripleBarrierLabel { + ret: returns[pt_touch_idx], + label: 1, + barrier_touch: pt_touch_idx as i64, + } + } else { + TripleBarrierLabel { + ret: returns[sl_touch_idx], + label: -1, + barrier_touch: sl_touch_idx as i64, + } + } + }, + (Some(pt_touch_idx), None) => TripleBarrierLabel { + ret: returns[pt_touch_idx], + label: 1, + barrier_touch: pt_touch_idx as i64, + }, + (None, Some(sl_touch_idx)) => TripleBarrierLabel { + ret: returns[sl_touch_idx], + label: -1, + barrier_touch: sl_touch_idx as i64, + }, + (None, None) => { + if zero_vertical_barrier { + TripleBarrierLabel { + ret: returns[returns.len() - 1], + label: 0, + barrier_touch: (returns.len() - 1) as i64, + } + } else { + TripleBarrierLabel { + ret: returns[returns.len() - 1], + label: returns[returns.len() - 1].signum() as i64, + barrier_touch: (returns.len() - 1) as i64, + } + } + }, + } +} + +struct TripleBarrierLabels { + rets: Vec, + labels: Vec, + barrier_touches: Vec, +} + +impl TripleBarrierLabels { + fn new() -> Self { + TripleBarrierLabels { + rets: Vec::new(), + labels: Vec::new(), + barrier_touches: Vec::new(), + } + } + fn new_with_capacity(capacity: usize) -> Self { + TripleBarrierLabels { + rets: Vec::with_capacity(capacity), + labels: Vec::with_capacity(capacity), + barrier_touches: Vec::with_capacity(capacity), + } + } +} + +fn calculate_labels( + index: Vec, + prices: Vec, + profit_taking: Vec>, + stop_loss: Vec>, + vertical_barriers: Vec>, + zero_vertical_barrier: bool, +) -> TripleBarrierLabels { + let mut labels = TripleBarrierLabels::new_with_capacity(prices.len()); + + for i in 0..index.len() { + let price_path = match vertical_barriers[i] { + Some(vb) => { + let (start_idx, end_idx) = get_slice_range(&index, index[i], vb).unwrap(); + calculate_price_path_return(prices[start_idx..end_idx].into()) + }, + None => calculate_price_path_return(prices[i..].into()), + }; + let label = get_label( + &price_path, + profit_taking[i], + stop_loss[i], + zero_vertical_barrier, + ); + labels.rets.push(label.ret); + labels.labels.push(label.label); + labels.barrier_touches.push(label.barrier_touch); + } + labels +} + +fn triple_barrier_struct(input_fields: &[Field]) -> PolarsResult { + Ok(Field::new( + "triple_barrier_label".into(), + DataType::Struct(vec![ + Field::new("price_path_return", DataType::Float64), + Field::new("price_path_label", DataType::Float64), + Field::new("barrier_touch", DataType::Int64), + ]), + )) +} + +// TODO: Provide bitmask for price paths to ignore +#[polars_expr(output_type_func=triple_barrier_struct)] +fn triple_barrier_label(inputs: &[Series]) -> PolarsResult { + // There should be no nulls in index + let index = &inputs[0]; + let index = if index.null_count() == 0 { + index.i64()?.to_vec_null_aware().left().unwrap() + } else { + return Err(PolarsError::InvalidOperation( + "Index should not contain null values".into(), + )); + }; + // There should be no null prices + let prices = &inputs[1]; + let prices = if prices.null_count() == 0 { + prices.f64()?.to_vec_null_aware().left().unwrap() + } else { + return Err(PolarsError::InvalidOperation( + "Prices should not contain null values".into(), + )); + }; + // Null price taking means we don't implement + let price_taking = inputs[2].f64()?.to_vec(); + // Null stop loss means we don't implement + let stop_loss = inputs[3].f64()?.to_vec(); + // Null vertical barrier means we don't implement + let vertical_barrier = inputs[4].i64()?.to_vec(); + let labels = calculate_labels( + index, + prices, + price_taking, + stop_loss, + vertical_barrier, + false, + ); + + // TODO + let ret_series = Float64Chunked::from_vec("ret", labels.rets); + let label_series = Int64Chunked::from_vec("label", labels.labels); + let barrier_touch_series = Int64Chunked::from_vec("barrier_touch", labels.barrier_touches); + let fields = vec![ + ret_series.into_series(), + label_series.into_series(), + barrier_touch_series.into_series(), + ]; + let struct_series = StructChunked::from_series("row_groups", &fields).unwrap(); + Ok(struct_series.into_series()) +} + +#[cfg(test)] +mod tests { + use super::*; + + // Tests for get_slice_range function + #[test] + fn test_get_slice_range_normal() { + let data = vec![1, 2, 3, 4, 5]; + assert_eq!(get_slice_range(&data, 2, 4), Ok((1, 3))); + } + + #[test] + fn test_get_slice_range_same_start_end() { + let data = vec![1, 2, 3, 4, 5]; + assert_eq!(get_slice_range(&data, 3, 3), Ok((2, 2))); + } + + #[test] + fn test_get_slice_range_start_not_found() { + let data = vec![1, 2, 3, 4, 5]; + assert_eq!( + get_slice_range(&data, 0, 4), + Err("Start index 0 not found in index".to_string()) + ); + } + + #[test] + fn test_get_slice_range_end_not_found() { + let data = vec![1, 2, 3, 4, 5]; + assert_eq!( + get_slice_range(&data, 2, 6), + Err("End index 6 not found in index".to_string()) + ); + } + + #[test] + fn test_get_slice_range_both_not_found() { + let data = vec![1, 2, 3, 4, 5]; + assert_eq!( + get_slice_range(&data, 0, 6), + Err("Both start index 0 and end index 6 not found in index".to_string()) + ); + } + + #[test] + fn test_get_slice_range_empty_vector() { + let data: Vec = vec![]; + assert_eq!( + get_slice_range(&data, 1, 2), + Err("Both start index 1 and end index 2 not found in index".to_string()) + ); + } + + // Tests for calculate_price_path_return function + #[test] + fn test_calculate_price_path_return_normal() { + let prices = vec![1.0, 2.0, 3.0]; + assert_eq!(calculate_price_path_return(prices), vec![0.0, 1.0, 2.0]); + } + + #[test] + fn test_calculate_price_path_return_single_price() { + let prices = vec![1.0]; + assert_eq!(calculate_price_path_return(prices), vec![0.0]); + } + + #[test] + fn test_calculate_price_path_return_decreasing_prices() { + use approx::assert_relative_eq; + let prices = vec![3.0, 2.0, 1.0]; + let result = calculate_price_path_return(prices); + let expected = vec![0.0, -1.0 / 3.0, -2.0 / 3.0]; + for (r, e) in result.iter().zip(expected.iter()) { + assert_relative_eq!(r, e, max_relative = 1e-5); + } + } + + // Tests for get_label function + #[test] + fn test_get_label_profit_taking() { + let returns = vec![0.0, 0.1, 0.2, 0.3]; + let label = get_label(&returns, Some(0.25), Some(-0.1), false); + assert_eq!(label.label, 1); + assert_eq!(label.barrier_touch, 3); + assert_eq!(label.ret, 0.3); + } + + #[test] + fn test_get_label_stop_loss() { + let returns = vec![0.0, -0.05, -0.1, -0.15]; + let label = get_label(&returns, Some(0.2), Some(-0.1), false); + assert_eq!(label.label, -1); + assert_eq!(label.barrier_touch, 2); + assert_eq!(label.ret, -0.1); + } + + #[test] + fn test_get_label_no_barrier_touch_zero_vertical() { + let returns = vec![0.0, 0.05, 0.08, 0.09]; + let label = get_label(&returns, Some(0.1), Some(-0.1), true); + assert_eq!(label.label, 0); + assert_eq!(label.barrier_touch, 3); + assert_eq!(label.ret, 0.09); + } + + #[test] + fn test_get_label_no_barrier_touch_non_zero_vertical() { + let returns = vec![0.0, 0.05, 0.08, 0.09]; + let label = get_label(&returns, Some(0.1), Some(-0.1), false); + assert_eq!(label.label, 1); + assert_eq!(label.barrier_touch, 3); + assert_eq!(label.ret, 0.09); + } + + #[test] + fn test_get_label_only_profit_taking() { + let returns = vec![0.0, 0.1, 0.2, 0.3]; + let label = get_label(&returns, Some(0.25), None, false); + assert_eq!(label.label, 1); + assert_eq!(label.barrier_touch, 3); + assert_eq!(label.ret, 0.3); + } + + #[test] + fn test_get_label_only_stop_loss() { + let returns = vec![0.0, -0.05, -0.1, -0.15]; + let label = get_label(&returns, None, Some(-0.1), false); + assert_eq!(label.label, -1); + assert_eq!(label.barrier_touch, 2); + assert_eq!(label.ret, -0.1); + } + + #[test] + fn test_get_label_no_barriers() { + let returns = vec![0.0, 0.05, -0.05, 0.1]; + let label = get_label(&returns, None, None, false); + assert_eq!(label.label, 1); + assert_eq!(label.barrier_touch, 3); + assert_eq!(label.ret, 0.1); + } + + #[test] + fn test_get_label_touches_pt_then_sl() { + let returns = vec![0.0, 0.1, -0.1, -0.15]; + let label = get_label(&returns, Some(0.1), Some(-0.1), false); + assert_eq!(label.label, 1); + assert_eq!(label.barrier_touch, 1); + assert_eq!(label.ret, 0.1); + } + + #[test] + fn test_get_label_touches_sl_then_pt() { + let returns = vec![0.0, -0.1, 0.1, -0.15]; + let label = get_label(&returns, Some(0.1), Some(-0.1), false); + assert_eq!(label.label, -1); + assert_eq!(label.barrier_touch, 1); + assert_eq!(label.ret, -0.1); + } + + #[test] + fn test_calculate_labels_basic() { + let index = vec![1, 2, 3, 4, 5]; + let prices = vec![100.0, 101.0, 102.0, 103.0, 104.0]; + let profit_taking = vec![Some(0.02); 5]; + let stop_loss = vec![Some(-0.01); 5]; + let vertical_barriers = vec![Some(5), Some(5), Some(5), None, None]; + let zero_vertical_barrier = false; + + let result = calculate_labels( + index, + prices, + profit_taking, + stop_loss, + vertical_barriers, + zero_vertical_barrier, + ); + + assert_eq!(result.rets, vec![0.02, 0.02, 0.02, 0.01, 0.0]); + assert_eq!(result.labels, vec![1, 1, 1, 1, 1]); + assert_eq!(result.barrier_touches, vec![2, 1, 0, 0, 0]); + } + + #[test] + fn test_calculate_labels_with_zero_vertical_barrier() { + let index = vec![1, 2, 3, 4, 5]; + let prices = vec![100.0, 99.0, 98.0, 97.0, 96.0]; + let profit_taking = vec![Some(0.02); 5]; + let stop_loss = vec![Some(-0.01); 5]; + let vertical_barriers = vec![Some(5); 5]; + let zero_vertical_barrier = true; + + let result = calculate_labels( + index, + prices, + profit_taking, + stop_loss, + vertical_barriers, + zero_vertical_barrier, + ); + + assert_eq!(result.rets, vec![-0.01, -0.01, -0.01, -0.01, -0.04]); + assert_eq!(result.labels, vec![-1, -1, -1, -1, 0]); + assert_eq!(result.barrier_touches, vec![1, 0, 0, 0, 4]); + } + + #[test] + fn test_calculate_labels_without_vertical_barriers() { + let index = vec![1, 2, 3, 4, 5]; + let prices = vec![100.0, 102.0, 104.0, 106.0, 108.0]; + let profit_taking = vec![Some(0.05); 5]; + let stop_loss = vec![Some(-0.03); 5]; + let vertical_barriers = vec![None; 5]; + let zero_vertical_barrier = false; + + let result = calculate_labels( + index, + prices, + profit_taking, + stop_loss, + vertical_barriers, + zero_vertical_barrier, + ); + + assert_eq!(result.rets, vec![0.05, 0.05, 0.05, 0.05, 0.08]); + assert_eq!(result.labels, vec![1, 1, 1, 1, 1]); + assert_eq!(result.barrier_touches, vec![2, 1, 0, 0, 4]); + } + + #[test] + fn test_calculate_labels_mixed_scenarios() { + let index = vec![1, 2, 3, 4, 5]; + let prices = vec![100.0, 99.0, 101.0, 98.0, 102.0]; + let profit_taking = vec![Some(0.02), None, Some(0.03), Some(0.01), None]; + let stop_loss = vec![Some(-0.01), Some(-0.02), None, None, Some(-0.03)]; + let vertical_barriers = vec![Some(3), Some(4), None, Some(5), None]; + let zero_vertical_barrier = false; + + let result = calculate_labels( + index, + prices, + profit_taking, + stop_loss, + vertical_barriers, + zero_vertical_barrier, + ); + + assert_eq!(result.rets, vec![-0.01, -0.01, 0.01, -0.02, 0.02]); + assert_eq!(result.labels, vec![-1, -1, 1, -1, 1]); + assert_eq!(result.barrier_touches, vec![1, 1, 0, 0, 1]); + } + + #[test] + fn test_calculate_labels_edge_cases() { + let index = vec![1]; + let prices = vec![100.0]; + let profit_taking = vec![None]; + let stop_loss = vec![None]; + let vertical_barriers = vec![None]; + let zero_vertical_barrier = true; + + let result = calculate_labels( + index, + prices, + profit_taking, + stop_loss, + vertical_barriers, + zero_vertical_barrier, + ); + + assert_eq!(result.rets, vec![0.0]); + assert_eq!(result.labels, vec![0]); + assert_eq!(result.barrier_touches, vec![0]); + } + + #[test] + #[should_panic] + fn test_calculate_labels_mismatched_input_lengths() { + let index = vec![1, 2, 3]; + let prices = vec![100.0, 101.0]; + let profit_taking = vec![Some(0.02); 3]; + let stop_loss = vec![Some(-0.01); 3]; + let vertical_barriers = vec![Some(3); 3]; + let zero_vertical_barrier = false; + + calculate_labels( + index, + prices, + profit_taking, + stop_loss, + vertical_barriers, + zero_vertical_barrier, + ); + } +}