From b6822a7a875996c092b8ac30df98377f00093792 Mon Sep 17 00:00:00 2001 From: Noah Clarkson Date: Mon, 2 Dec 2024 20:36:05 +1300 Subject: [PATCH] Switch connector to async --- Cargo.toml | 7 +- src/config.rs | 10 ++- src/data/candlestick.rs | 149 ++-------------------------------------- src/data/dataset.rs | 119 ++++++-------------------------- src/krypto_account.rs | 0 src/main.rs | 12 ++-- src/util/mod.rs | 1 - src/util/test_util.rs | 57 --------------- tests/algorithm.rs | 55 --------------- 9 files changed, 42 insertions(+), 368 deletions(-) create mode 100644 src/krypto_account.rs delete mode 100644 src/util/test_util.rs delete mode 100644 tests/algorithm.rs diff --git a/Cargo.toml b/Cargo.toml index 25e4c76..3ad247d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -binance = { git = "https://github.com/wisespace-io/binance-rs.git" } +binance-rs-async = "1.3.3" serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9" chrono = "0.4" @@ -24,6 +24,5 @@ linfa-pls = "0.7" ndarray = "0.15" derive_builder = "0.20" genevo = "0.7" - -[dev-dependencies] -tempfile = "3.14" \ No newline at end of file +tokio = { version = "1.0", features = ["full"] } +futures = "0.3" \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 779c938..3b1f67e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -117,7 +117,7 @@ impl KryptoConfig { /// /// A `Result` containing the `KryptoConfig` on success or an `Error` on failure. #[instrument(level = "info", skip(filename))] - pub fn read_config>(filename: Option

) -> Result { + pub async fn read_config>(filename: Option

) -> Result { let path = filename .map(|p| p.as_ref().to_path_buf()) .unwrap_or_else(|| Path::new("config.yml").to_path_buf()); @@ -138,16 +138,14 @@ impl KryptoConfig { let file = File::open(&path)?; let reader = BufReader::new(file); let config: Self = from_reader(reader)?; - let account: Account = config.get_binance(); + let account = config.get_binance::(); if config.api_key.is_some() || config.api_secret.is_some() { - let account_info = account.get_account().map_err(|e| { + let account_info = account.get_account().await.map_err(|e| { error!("Failed to get account info: {}", e); KryptoError::BinanceApiError(e.to_string()) })?; for asset in account_info.balances { - let free = asset.free.parse::().unwrap_or(0.0); - let locked = asset.locked.parse::().unwrap_or(0.0); - if free + locked > 0.0 { + if asset.free + asset.locked > 0.0 { info!( "Asset: {}, Free: {}, Locked: {}", asset.asset, asset.free, asset.locked diff --git a/src/data/candlestick.rs b/src/data/candlestick.rs index cbae497..4c742c0 100644 --- a/src/data/candlestick.rs +++ b/src/data/candlestick.rs @@ -1,4 +1,4 @@ -use binance::model::{KlineSummaries, KlineSummary}; +use binance::rest_model::{KlineSummaries, KlineSummary}; use chrono::{DateTime, TimeZone, Utc}; use derive_builder::Builder; @@ -48,29 +48,11 @@ impl Candlestick { Ok(Self { open_time, close_time, - open: summary.open.parse().map_err(|_| KryptoError::ParseError { - value_name: "open".to_string(), - timestamp: summary.open_time, - })?, - high: summary.high.parse().map_err(|_| KryptoError::ParseError { - value_name: "high".to_string(), - timestamp: summary.open_time, - })?, - low: summary.low.parse().map_err(|_| KryptoError::ParseError { - value_name: "low".to_string(), - timestamp: summary.open_time, - })?, - close: summary.close.parse().map_err(|_| KryptoError::ParseError { - value_name: "close".to_string(), - timestamp: summary.open_time, - })?, - volume: summary - .volume - .parse() - .map_err(|_| KryptoError::ParseError { - value_name: "volume".to_string(), - timestamp: summary.open_time, - })?, + open: summary.open, + high: summary.high, + low: summary.low, + close: summary.close, + volume: summary.volume, }) } @@ -139,121 +121,4 @@ impl PartialOrd for Candlestick { fn partial_cmp(&self, other: &Self) -> Option { self.open_time.partial_cmp(&other.open_time) } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::error::KryptoError; - use chrono::TimeZone; - - #[test] - fn test_candlestick_from_summary() { - let summary = KlineSummary { - open_time: 1618185600000, - close_time: 1618185659999, - open: "0.00000000".to_string(), - high: "0.00000000".to_string(), - low: "0.00000000".to_string(), - close: "0.00000000".to_string(), - volume: "0.00000000".to_string(), - quote_asset_volume: "0.00000000".to_string(), - number_of_trades: 0, - taker_buy_base_asset_volume: "0.00000000".to_string(), - taker_buy_quote_asset_volume: "0.00000000".to_string(), - }; - - let candlestick = Candlestick::from_summary(summary).unwrap(); - assert_eq!( - candlestick.open_time, - Utc.timestamp_millis_opt(1618185600000).single().unwrap() - ); - assert_eq!( - candlestick.close_time, - Utc.timestamp_millis_opt(1618185659999).single().unwrap() - ); - assert_eq!(candlestick.open, 0.0); - assert_eq!(candlestick.high, 0.0); - assert_eq!(candlestick.low, 0.0); - assert_eq!(candlestick.close, 0.0); - assert_eq!(candlestick.volume, 0.0); - } - - #[test] - fn test_candlestick_from_summary_invalid_open_time() { - let summary = KlineSummary { - open_time: 16181856599999998, - close_time: 16181856599999999, - open: "0.00000000".to_string(), - high: "0.00000000".to_string(), - low: "0.00000000".to_string(), - close: "0.00000000".to_string(), - volume: "0.00000000".to_string(), - quote_asset_volume: "0.00000000".to_string(), - number_of_trades: 0, - taker_buy_base_asset_volume: "0.00000000".to_string(), - taker_buy_quote_asset_volume: "0.00000000".to_string(), - }; - - let result = Candlestick::from_summary(summary); - assert!(matches!( - result, - Err(KryptoError::InvalidCandlestickDateTime { - when: When::Open, - timestamp: 16181856599999998 - }) - )); - } - - #[test] - fn test_candlestick_from_summary_invalid_close_time() { - let summary = KlineSummary { - open_time: 1618185600000, - close_time: 16181856599999999, - open: "0.00000000".to_string(), - high: "0.00000000".to_string(), - low: "0.00000000".to_string(), - close: "0.00000000".to_string(), - volume: "0.00000000".to_string(), - quote_asset_volume: "0.00000000".to_string(), - number_of_trades: 0, - taker_buy_base_asset_volume: "0.00000000".to_string(), - taker_buy_quote_asset_volume: "0.00000000".to_string(), - }; - - let result = Candlestick::from_summary(summary); - assert!(matches!( - result, - Err(KryptoError::InvalidCandlestickDateTime { - when: When::Close, - timestamp: 16181856599999999 - }) - )); - } - - #[test] - fn test_candlestick_from_summary_open_time_greater_than_close_time() { - let summary = KlineSummary { - open_time: 1618185659999, - close_time: 1618185600000, - open: "0.00000000".to_string(), - high: "0.00000000".to_string(), - low: "0.00000000".to_string(), - close: "0.00000000".to_string(), - volume: "0.00000000".to_string(), - quote_asset_volume: "0.00000000".to_string(), - number_of_trades: 0, - taker_buy_base_asset_volume: "0.00000000".to_string(), - taker_buy_quote_asset_volume: "0.00000000".to_string(), - }; - - let result = Candlestick::from_summary(summary); - assert!(matches!( - result, - Err(KryptoError::OpenTimeGreaterThanCloseTime { - open_time: 1618185659999, - close_time: 1618185600000 - }) - )); - } -} +} \ No newline at end of file diff --git a/src/data/dataset.rs b/src/data/dataset.rs index 93ad853..6bb9781 100644 --- a/src/data/dataset.rs +++ b/src/data/dataset.rs @@ -29,13 +29,13 @@ impl Dataset { The loaded dataset if successful, or a KryptoError if an error occurred. */ #[instrument(skip(config))] - pub fn load(config: &KryptoConfig) -> Result { + pub async fn load(config: &KryptoConfig) -> Result { let mut interval_data_map = HashMap::new(); let market: Market = config.get_binance(); for interval in &config.intervals { let interval = *interval; - let interval_data = IntervalData::load(&interval, config, &market)?; + let interval_data = IntervalData::load(&interval, config, &market).await?; info!("Loaded data for {}", &interval); interval_data_map.insert(interval, interval_data); } @@ -93,20 +93,22 @@ pub struct IntervalData { impl IntervalData { #[instrument(skip(config, market))] - fn load( + async fn load( interval: &Interval, config: &KryptoConfig, market: &Market, ) -> Result { - let mut symbol_data_map = HashMap::new(); let end = Utc::now().timestamp_millis(); - + let mut tasks = Vec::new(); for symbol in &config.symbols { - let symbol = symbol.clone(); - let symbol_data = RawSymbolData::load(interval, &symbol, end, config, market)?; - info!("Loaded data for {}", &symbol); - symbol_data_map.insert(symbol, symbol_data); + let task = RawSymbolData::load(interval, symbol, end, config, market); + tasks.push(task); } + let result = futures::future::try_join_all(tasks).await?; + let symbol_data_map: HashMap = result + .into_iter() + .map(|data| (data.symbol.clone(), data)) + .collect(); let records = get_records(&symbol_data_map); let normalized_predictors = get_normalized_predictors(records); @@ -280,11 +282,12 @@ struct RawSymbolData { candles: Vec, technicals: Vec, labels: Vec, + symbol: String, } impl RawSymbolData { #[instrument(skip(interval, end, config, market))] - fn load( + async fn load( interval: &Interval, symbol: &str, end: i64, @@ -296,7 +299,7 @@ impl RawSymbolData { let timestamps = get_timestamps(start.timestamp_millis(), end, *interval)?; for (start, end) in timestamps { - let mut chunk = Self::load_chunk(market, symbol, interval, start, end)?; + let mut chunk = Self::load_chunk(market, symbol, interval, start, end).await?; candles.append(&mut chunk); } @@ -311,20 +314,22 @@ impl RawSymbolData { labels.push(percentage_change.signum()); } debug!( - "Loaded {} candles ({} labels | {}x{} technicals)", + "Loaded {} candles ({} labels | {}x{} technicals) for {}", candles.len(), labels.len(), technicals.len(), - technicals[0].as_array().len() + technicals[0].as_array().len(), + symbol ); Ok(Self { candles, technicals, labels, + symbol: symbol.to_string(), }) } - fn load_chunk( + async fn load_chunk( market: &Market, symbol: &str, interval: &Interval, @@ -339,6 +344,7 @@ impl RawSymbolData { Some(start as u64), Some(end as u64), ) + .await .map_err(|e| KryptoError::BinanceApiError(e.to_string()))?; let candlesticks = Candlestick::map_to_candlesticks(summaries)?; Ok(candlesticks) @@ -363,87 +369,4 @@ impl RawSymbolData { fn recompute_technicals(&mut self, technical_names: Vec) { self.technicals = Technicals::get_technicals(&self.candles, technical_names); } -} - -#[cfg(test)] -mod tests { - use tracing::info; - - use crate::{ - config::KryptoConfig, - util::{date_utils::MINS_TO_MILLIS, test_util::setup_default_data}, - }; - - #[test] - #[ignore] - fn test_data_load() { - let _ = setup_default_data("data_load", None); - } - - #[test] - #[ignore] - fn test_data_shape() { - let (dataset, _gaurds) = setup_default_data("data_shape", None); - let shape = dataset.shape(); - info!("{:?}", shape); - assert_eq!((shape.0, shape.1), (2, 2)); - for value in dataset.values() { - let data_lengths = value - .values() - .map(|d| d.get_candles().len()) - .collect::>(); - let technicals_lengths = value - .values() - .map(|d| d.get_technicals().len()) - .collect::>(); - let labels_lengths = value - .values() - .map(|d| d.get_labels().len()) - .collect::>(); - assert!(data_lengths.iter().all(|&x| x == data_lengths[0])); - assert!(technicals_lengths - .iter() - .all(|&x| x == technicals_lengths[0])); - assert!(labels_lengths.iter().all(|&x| x == labels_lengths[0])); - } - } - - #[test] - #[ignore] - fn test_data_times_match() { - let config = KryptoConfig { - start_date: "2021-02-02".to_string(), - symbols: vec![ - "BTCUSDT".to_string(), - "ETHUSDT".to_string(), - "BNBUSDT".to_string(), - "ADAUSDT".to_string(), - "XRPUSDT".to_string(), - ], - ..Default::default() - }; - let (dataset, _gaurds) = setup_default_data("data_times_match", Some(config)); - for (key, value) in dataset.get_map() { - let maximum_variance = key.to_minutes() * MINS_TO_MILLIS / 2; - let symbol_datas = value.values(); - let times = symbol_datas - .map(|d| { - d.get_candles() - .clone() - .iter() - .map(|v| v.close_time) - .collect::>() - }) - .collect::>(); - for i in 0..times[0].len() { - for j in 0..times.len() { - for k in 0..times.len() { - let difference = (times[j][i] - times[k][i]).abs(); - let difference = difference.num_milliseconds(); - assert!(difference <= maximum_variance, "Difference: {}", difference); - } - } - } - } - } -} +} \ No newline at end of file diff --git a/src/krypto_account.rs b/src/krypto_account.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/main.rs b/src/main.rs index 953452e..1dff905 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,18 +19,20 @@ use krypto::{ }; use tracing::{error, info}; -pub fn main() { +#[tokio::main] +pub async fn main() { let (_, file_guard) = setup_tracing(Some("logs")).expect("Failed to set up tracing"); - let result = run(); + let result = run().await; if let Err(e) = result { error!("Error: {:?}", e); } drop(file_guard); } -fn run() -> Result<(), KryptoError> { - let config = KryptoConfig::read_config::<&str>(None)?; - let dataset = Dataset::load(&config)?; + +async fn run() -> Result<(), KryptoError> { + let config = KryptoConfig::read_config::<&str>(None).await?; + let dataset = Dataset::load(&config).await?; let selection_ratio = 0.7; let num_individuals_per_parents = 2; diff --git a/src/util/mod.rs b/src/util/mod.rs index 5ad9735..23e1d4d 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,4 +1,3 @@ pub mod date_utils; pub mod math_utils; pub mod matrix_utils; -pub mod test_util; diff --git a/src/util/test_util.rs b/src/util/test_util.rs deleted file mode 100644 index 405e8de..0000000 --- a/src/util/test_util.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::path::Path; -use tracing::{info, subscriber::set_default}; -use tracing_appender::non_blocking::WorkerGuard; -use tracing_subscriber::fmt; - -use crate::{config::KryptoConfig, data::dataset::Dataset}; - -pub struct TracingGuards { - _subscriber_guard: tracing::subscriber::DefaultGuard, - _worker_guard: WorkerGuard, -} - -pub fn setup_test_tracing(test_name: &str) -> TracingGuards { - // Create logs directory if it doesn't exist - let log_dir = Path::new("tests/logs"); - if !log_dir.exists() { - std::fs::create_dir_all(log_dir).unwrap(); - } - - // Set up file appender with non-blocking writer - let log_file = format!("tests/logs/{}.log", test_name); - let file_appender = tracing_appender::rolling::never("", &log_file); - let (non_blocking, worker_guard) = tracing_appender::non_blocking(file_appender); - - // Set up subscriber - let subscriber = fmt::Subscriber::builder() - .with_writer(non_blocking) - .with_ansi(false) - .with_level(true) - .with_target(true) - .with_thread_ids(true) - .with_thread_names(true) - .with_max_level(tracing::Level::DEBUG) - .finish(); - - // Set as default subscriber for this thread - let subscriber_guard = set_default(subscriber); - - // Return guards to keep the subscriber and writer alive - TracingGuards { - _subscriber_guard: subscriber_guard, - _worker_guard: worker_guard, - } -} - -pub fn setup_default_data( - test_name: &str, - config: Option, -) -> (Dataset, TracingGuards) { - let guards = setup_test_tracing(test_name); - info!("-----------------"); - info!("Test: {}", test_name); - info!("-----------------"); - let config = config.unwrap_or_default(); - let dataset = Dataset::load(&config).unwrap(); - (dataset, guards) -} diff --git a/tests/algorithm.rs b/tests/algorithm.rs deleted file mode 100644 index 301b217..0000000 --- a/tests/algorithm.rs +++ /dev/null @@ -1,55 +0,0 @@ - -use krypto::{ - algorithm::algo::{Algorithm, AlgorithmSettings}, - config::KryptoConfig, - data::interval::Interval, util::test_util::setup_default_data, -}; -use tracing::info; - -#[test] -#[ignore] -fn test_algorithm() { - let config = KryptoConfig { - start_date: "2017-01-01".to_string(), - symbols: vec!["BTCUSDT".to_string()], - intervals: vec![Interval::OneDay], - cross_validations: 30, - ..Default::default() - }; - let (dataset, _gaurds) = setup_default_data("algorithm", Some(config.clone())); - info!("Shape: {:?}", dataset.shape()); - let interval = dataset.keys().next().unwrap(); - let interval_data = dataset.get(interval).unwrap(); - let symbol = config.symbols[0].clone(); - let settings = AlgorithmSettings::new(3, 3, &symbol); - let result = Algorithm::load(interval_data, settings, &config); - match result { - Ok(_) => { - info!("Algorithm Loaded Successfully"); - } - Err(e) => { - panic!("Error: {}", e); - } - } -} - -#[test] -#[ignore] -fn test_algo_on_all_data() { - let config = KryptoConfig { - start_date: "2019-01-01".to_string(), - symbols: vec!["BTCUSDT".to_string(), "ETHUSDT".to_string(), "BNBUSDT".to_string()], - intervals: vec![Interval::TwoHours], - ..Default::default() - }; - let (dataset, _gaurds) = setup_default_data("algo_on_all_unseen_data", Some(config.clone())); - info!("Shape: {:?}", dataset.shape()); - let interval = dataset.keys().next().unwrap(); - let interval_data = dataset.get(interval).unwrap(); - let symbol = config.symbols[0].clone(); - let settings = AlgorithmSettings::new(10, 18, &symbol); - let result = Algorithm::load(interval_data, settings, &config).unwrap(); - info!("Algorithm Loaded Successfully"); - let algo_result = result.backtest_on_all_seen_data(interval_data, &config).unwrap(); - info!("Algorithm Result: {}", algo_result); -}