From 27c0819a6e71013c7bfac98c0e2dee51274cfdb0 Mon Sep 17 00:00:00 2001 From: Max Kalashnikoff Date: Thu, 2 Jan 2025 23:54:18 +0100 Subject: [PATCH] feat(balances): adding balance providers weights and retrying --- src/analytics/balance_lookup_info.rs | 9 +- src/env/mod.rs | 15 ++- src/env/solscan.rs | 39 ++++++ src/env/zerion.rs | 40 ++++++ src/error.rs | 11 ++ src/handlers/balance.rs | 45 ++++--- src/handlers/proxy.rs | 4 +- src/handlers/supported_chains.rs | 2 +- src/lib.rs | 75 ++++++++---- src/providers/mod.rs | 174 ++++++++++++++++++++++----- src/providers/solscan.rs | 14 ++- src/providers/weights.rs | 6 +- src/providers/zerion.rs | 15 ++- 13 files changed, 357 insertions(+), 92 deletions(-) create mode 100644 src/env/solscan.rs create mode 100644 src/env/zerion.rs diff --git a/src/analytics/balance_lookup_info.rs b/src/analytics/balance_lookup_info.rs index be1eec2be..b86440063 100644 --- a/src/analytics/balance_lookup_info.rs +++ b/src/analytics/balance_lookup_info.rs @@ -1,14 +1,9 @@ -use { - parquet_derive::ParquetRecordWriter, - serde::Serialize, - std::{sync::Arc, time::Duration}, -}; +use {parquet_derive::ParquetRecordWriter, serde::Serialize, std::sync::Arc}; #[derive(Debug, Clone, Serialize, ParquetRecordWriter)] #[serde(rename_all = "camelCase")] pub struct BalanceLookupInfo { pub timestamp: chrono::NaiveDateTime, - pub latency_secs: f64, pub symbol: String, pub implementation_chain_id: String, @@ -29,7 +24,6 @@ pub struct BalanceLookupInfo { impl BalanceLookupInfo { #[allow(clippy::too_many_arguments)] pub fn new( - latency: Duration, symbol: String, implementation_chain_id: String, quantity: String, @@ -45,7 +39,6 @@ impl BalanceLookupInfo { ) -> Self { Self { timestamp: wc::analytics::time::now(), - latency_secs: latency.as_secs_f64(), symbol, implementation_chain_id, quantity, diff --git a/src/env/mod.rs b/src/env/mod.rs index e8d5a0524..9d3019956 100644 --- a/src/env/mod.rs +++ b/src/env/mod.rs @@ -8,15 +8,15 @@ use { project::{storage::Config as StorageConfig, Config as RegistryConfig}, providers::{ProviderKind, ProvidersConfig, Weight}, storage::irn::Config as IrnConfig, - utils::rate_limit::RateLimitingConfig, + utils::{crypto::CaipNamespaces, rate_limit::RateLimitingConfig}, }, serde::de::DeserializeOwned, std::{collections::HashMap, fmt::Display}, }; pub use { arbitrum::*, aurora::*, base::*, berachain::*, binance::*, getblock::*, infura::*, lava::*, - mantle::*, morph::*, near::*, pokt::*, publicnode::*, quicknode::*, server::*, unichain::*, - zksync::*, zora::*, + mantle::*, morph::*, near::*, pokt::*, publicnode::*, quicknode::*, server::*, solscan::*, + unichain::*, zerion::*, zksync::*, zora::*, }; mod arbitrum; mod aurora; @@ -33,7 +33,9 @@ mod pokt; mod publicnode; mod quicknode; mod server; +pub mod solscan; mod unichain; +pub mod zerion; mod zksync; mod zora; @@ -87,6 +89,11 @@ pub trait ProviderConfig { fn provider_kind(&self) -> ProviderKind; } +pub trait BalanceProviderConfig { + fn supported_namespaces(self) -> HashMap; + fn provider_kind(&self) -> ProviderKind; +} + #[cfg(test)] #[cfg(not(feature = "test-mock-bundler"))] // These tests depend on environment variables mod test { @@ -277,7 +284,7 @@ mod test { infura_project_id: "INFURA_PROJECT_ID".to_string(), pokt_project_id: "POKT_PROJECT_ID".to_string(), quicknode_api_tokens: "QUICKNODE_API_TOKENS".to_string(), - zerion_api_key: Some("ZERION_API_KEY".to_owned()), + zerion_api_key: "ZERION_API_KEY".to_owned(), coinbase_api_key: Some("COINBASE_API_KEY".to_owned()), coinbase_app_id: Some("COINBASE_APP_ID".to_owned()), one_inch_api_key: Some("ONE_INCH_API_KEY".to_owned()), diff --git a/src/env/solscan.rs b/src/env/solscan.rs new file mode 100644 index 000000000..9f8b7e5cb --- /dev/null +++ b/src/env/solscan.rs @@ -0,0 +1,39 @@ +use { + super::BalanceProviderConfig, + crate::{ + providers::{Priority, Weight}, + utils::crypto::CaipNamespaces, + }, + std::collections::HashMap, +}; + +pub struct SolScanConfig { + pub api_key: String, + pub supported_namespaces: HashMap, +} + +impl SolScanConfig { + pub fn new(api_key: String) -> Self { + Self { + api_key, + supported_namespaces: default_supported_namespaces(), + } + } +} + +impl BalanceProviderConfig for SolScanConfig { + fn supported_namespaces(self) -> HashMap { + self.supported_namespaces + } + + fn provider_kind(&self) -> crate::providers::ProviderKind { + crate::providers::ProviderKind::SolScan + } +} + +fn default_supported_namespaces() -> HashMap { + HashMap::from([( + CaipNamespaces::Solana, + Weight::new(Priority::Normal).unwrap(), + )]) +} diff --git a/src/env/zerion.rs b/src/env/zerion.rs new file mode 100644 index 000000000..b8bf6ee52 --- /dev/null +++ b/src/env/zerion.rs @@ -0,0 +1,40 @@ +use { + super::BalanceProviderConfig, + crate::{ + providers::{Priority, Weight}, + utils::crypto::CaipNamespaces, + }, + std::collections::HashMap, +}; + +#[derive(Debug)] +pub struct ZerionConfig { + pub api_key: String, + pub supported_namespaces: HashMap, +} + +impl ZerionConfig { + pub fn new(api_key: String) -> Self { + Self { + api_key, + supported_namespaces: default_supported_namespaces(), + } + } +} + +impl BalanceProviderConfig for ZerionConfig { + fn supported_namespaces(self) -> HashMap { + self.supported_namespaces + } + + fn provider_kind(&self) -> crate::providers::ProviderKind { + crate::providers::ProviderKind::Zerion + } +} + +fn default_supported_namespaces() -> HashMap { + HashMap::from([( + CaipNamespaces::Eip155, + Weight::new(Priority::Normal).unwrap(), + )]) +} diff --git a/src/error.rs b/src/error.rs index 8a9bb150d..3e46fda9c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -75,6 +75,9 @@ pub enum RpcError { #[error("Failed to reach the balance provider")] BalanceProviderError, + #[error("Requested balance provider for the namespace is temporarily unavailable: {0}")] + BalanceTemporarilyUnavailable(String), + #[error("Failed to reach the fungible price provider: {0}")] FungiblePriceProviderError(String), @@ -282,6 +285,14 @@ impl IntoResponse for RpcError { )), ) .into_response(), + Self::BalanceTemporarilyUnavailable(namespace) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(new_error_response( + "chainId".to_string(), + format!("Requested namespace {namespace} balance provider is temporarily unavailable"), + )), + ) + .into_response(), Self::InvalidChainIdFormat(chain_id) => ( StatusCode::BAD_REQUEST, Json(new_error_response( diff --git a/src/handlers/balance.rs b/src/handlers/balance.rs index c64455249..8cd335b8e 100644 --- a/src/handlers/balance.rs +++ b/src/handlers/balance.rs @@ -14,16 +14,14 @@ use { ethers::{abi::Address, types::H160}, hyper::HeaderMap, serde::{Deserialize, Serialize}, - std::{ - net::SocketAddr, - sync::Arc, - time::{Duration, SystemTime}, - }, + std::{net::SocketAddr, sync::Arc}, tap::TapFallible, tracing::log::{debug, error}, wc::future::FutureExt, }; +const PROVIDER_MAX_CALLS: usize = 2; + #[derive(Debug, Deserialize, Clone)] #[serde(rename_all = "camelCase")] pub struct BalanceQueryParams { @@ -108,20 +106,32 @@ async fn handler_internal( return Err(RpcError::InvalidAddress); } - let provider = state + let providers = state .providers - .balance_providers - .get(&namespace) - .ok_or_else(|| RpcError::UnsupportedNamespace(namespace))?; + .get_balance_provider_for_namespace(&namespace, PROVIDER_MAX_CALLS)?; - let start = SystemTime::now(); - let mut response = provider - .get_balance(address.clone(), query.clone().0, state.metrics.clone()) - .await - .tap_err(|e| { - error!("Failed to call balance with {}", e); - })?; - let latency = start.elapsed().unwrap_or(Duration::from_secs(0)); + let mut balance_response = None; + for provider in providers.iter() { + let provider_response = provider + .get_balance(address.clone(), query.clone().0, state.metrics.clone()) + .await + .tap_err(|e| { + error!("Failed to call balance with {}", e); + }); + + match provider_response { + Ok(response) => { + balance_response = Some(response); + break; + } + e => { + debug!("Balance provider returned an error {e:?}, trying the next provider"); + } + }; + } + let mut response = balance_response.ok_or(RpcError::BalanceTemporarilyUnavailable( + namespace.to_string(), + ))?; { let origin = headers @@ -137,7 +147,6 @@ async fn handler_internal( .unwrap_or((None, None, None)); for balance in &response.balances { state.analytics.balance_lookup(BalanceLookupInfo::new( - latency, balance.symbol.clone(), balance.chain_id.clone().unwrap_or_default(), balance.quantity.numeric.clone(), diff --git a/src/handlers/proxy.rs b/src/handlers/proxy.rs index 021c61a25..6de16b52a 100644 --- a/src/handlers/proxy.rs +++ b/src/handlers/proxy.rs @@ -73,7 +73,7 @@ pub async fn rpc_call( Some(provider_id) => { let provider = vec![state .providers - .get_provider_by_provider_id(&provider_id) + .get_rpc_provider_by_provider_id(&provider_id) .ok_or_else(|| RpcError::UnsupportedProvider(provider_id.clone()))?]; if let Some(ref testing_project_id) = state.config.server.testing_project_id { @@ -94,7 +94,7 @@ pub async fn rpc_call( } None => state .providers - .get_provider_for_chain_id(&chain_id, PROVIDER_PROXY_MAX_CALLS)?, + .get_rpc_provider_for_chain_id(&chain_id, PROVIDER_PROXY_MAX_CALLS)?, }; for (i, provider) in providers.iter().enumerate() { diff --git a/src/handlers/supported_chains.rs b/src/handlers/supported_chains.rs index 439f99b1a..20868caa8 100644 --- a/src/handlers/supported_chains.rs +++ b/src/handlers/supported_chains.rs @@ -16,5 +16,5 @@ pub async fn handler(state: State>) -> Result>, ) -> Result, RpcError> { - Ok(Json(state.providers.supported_chains.clone())) + Ok(Json(state.providers.rpc_supported_chains.clone())) } diff --git a/src/lib.rs b/src/lib.rs index 018396917..57516f9d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,7 +20,8 @@ use { env::{ ArbitrumConfig, AuroraConfig, BaseConfig, BerachainConfig, BinanceConfig, GetBlockConfig, InfuraConfig, LavaConfig, MantleConfig, MorphConfig, NearConfig, PoktConfig, - PublicnodeConfig, QuicknodeConfig, UnichainConfig, ZKSyncConfig, ZoraConfig, + PublicnodeConfig, QuicknodeConfig, SolScanConfig, UnichainConfig, ZKSyncConfig, + ZerionConfig, ZoraConfig, }, error::RpcResult, http::Request, @@ -29,7 +30,8 @@ use { ArbitrumProvider, AuroraProvider, BaseProvider, BerachainProvider, BinanceProvider, GetBlockProvider, InfuraProvider, InfuraWsProvider, LavaProvider, MantleProvider, MorphProvider, NearProvider, PoktProvider, ProviderRepository, PublicnodeProvider, - QuicknodeProvider, UnichainProvider, ZKSyncProvider, ZoraProvider, ZoraWsProvider, + QuicknodeProvider, SolScanProvider, UnichainProvider, ZKSyncProvider, ZerionProvider, + ZoraProvider, ZoraWsProvider, }, sqlx::postgres::PgPoolOptions, std::{ @@ -44,7 +46,7 @@ use { trace::TraceLayer, ServiceBuilderExt, }, - tracing::{info, log::warn, Span}, + tracing::{error, info, log::warn, Span}, utils::rate_limit::RateLimit, wc::{ geoip::{ @@ -462,36 +464,56 @@ fn create_server( } fn init_providers(config: &ProvidersConfig) -> ProviderRepository { - let mut providers = ProviderRepository::new(config); + // Redis pool for providers responses caching where needed + let mut redis_pool = None; + if let Some(redis_addr) = &config.cache_redis_addr { + let redis_builder = deadpool_redis::Config::from_url(redis_addr) + .builder() + .map_err(|e| { + error!( + "Failed to create redis pool builder for provider's responses caching: {:?}", + e + ); + }) + .expect("Failed to create redis pool builder for provider's responses caching, builder is None"); + redis_pool = Some(Arc::new( + redis_builder + .runtime(deadpool_redis::Runtime::Tokio1) + .build() + .expect("Failed to create redis pool"), + )); + }; // Keep in-sync with SUPPORTED_CHAINS.md - providers.add_provider::(AuroraConfig::default()); - providers.add_provider::(ArbitrumConfig::default()); - providers - .add_provider::(PoktConfig::new(config.pokt_project_id.clone())); + let mut providers = ProviderRepository::new(config); + providers.add_rpc_provider::(AuroraConfig::default()); + providers.add_rpc_provider::(ArbitrumConfig::default()); + providers.add_rpc_provider::(PoktConfig::new( + config.pokt_project_id.clone(), + )); - providers.add_provider::(BaseConfig::default()); - providers.add_provider::(BinanceConfig::default()); - providers.add_provider::(ZKSyncConfig::default()); - providers.add_provider::(PublicnodeConfig::default()); - providers.add_provider::(QuicknodeConfig::new( + providers.add_rpc_provider::(BaseConfig::default()); + providers.add_rpc_provider::(BinanceConfig::default()); + providers.add_rpc_provider::(ZKSyncConfig::default()); + providers.add_rpc_provider::(PublicnodeConfig::default()); + providers.add_rpc_provider::(QuicknodeConfig::new( config.quicknode_api_tokens.clone(), )); - providers.add_provider::(InfuraConfig::new( + providers.add_rpc_provider::(InfuraConfig::new( config.infura_project_id.clone(), )); - providers.add_provider::(ZoraConfig::default()); - providers.add_provider::(NearConfig::default()); - providers.add_provider::(MantleConfig::default()); - providers.add_provider::(BerachainConfig::default()); - providers.add_provider::(UnichainConfig::default()); + providers.add_rpc_provider::(ZoraConfig::default()); + providers.add_rpc_provider::(NearConfig::default()); + providers.add_rpc_provider::(MantleConfig::default()); + providers.add_rpc_provider::(BerachainConfig::default()); + providers.add_rpc_provider::(UnichainConfig::default()); providers - .add_provider::(LavaConfig::new(config.lava_api_key.clone())); - providers.add_provider::(MorphConfig::default()); + .add_rpc_provider::(LavaConfig::new(config.lava_api_key.clone())); + providers.add_rpc_provider::(MorphConfig::default()); if let Some(getblock_access_tokens) = &config.getblock_access_tokens { - providers.add_provider::(GetBlockConfig::new( + providers.add_rpc_provider::(GetBlockConfig::new( getblock_access_tokens.clone(), )); }; @@ -501,6 +523,15 @@ fn init_providers(config: &ProvidersConfig) -> ProviderRepository { )); providers.add_ws_provider::(ZoraConfig::default()); + providers.add_balance_provider::( + ZerionConfig::new(config.zerion_api_key.clone()), + None, + ); + providers.add_balance_provider::( + SolScanConfig::new(config.solscan_api_v2_token.clone()), + redis_pool.clone(), + ); + providers } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index f430539ea..58a16173a 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,7 +1,7 @@ use { - self::{coinbase::CoinbaseProvider, zerion::ZerionProvider}, + self::coinbase::CoinbaseProvider, crate::{ - env::ProviderConfig, + env::{BalanceProviderConfig, ProviderConfig}, error::{RpcError, RpcResult}, handlers::{ balance::{self, BalanceQueryParams, BalanceResponseBody}, @@ -32,6 +32,7 @@ use { async_trait::async_trait, axum::response::Response, axum_tungstenite::WebSocketUpgrade, + deadpool_redis::Pool, hyper::http::HeaderValue, mock_alto::{MockAltoProvider, MockAltoUrls}, rand::{distributions::WeightedIndex, prelude::Distribution, rngs::OsRng}, @@ -95,13 +96,15 @@ pub use { solscan::SolScanProvider, tenderly::TenderlyProvider, unichain::UnichainProvider, + zerion::ZerionProvider, zksync::ZKSyncProvider, zora::{ZoraProvider, ZoraWsProvider}, }; static WS_PROXY_TASK_METRICS: TaskMetrics = TaskMetrics::new("ws_proxy_task"); -pub type WeightResolver = HashMap>; +pub type ChainsWeightResolver = HashMap>; +pub type NamespacesWeightResolver = HashMap>; #[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize)] pub struct ProvidersConfig { @@ -115,7 +118,7 @@ pub struct ProvidersConfig { pub pokt_project_id: String, pub quicknode_api_tokens: String, - pub zerion_api_key: Option, + pub zerion_api_key: String, pub coinbase_api_key: Option, pub coinbase_app_id: Option, pub one_inch_api_key: Option, @@ -147,27 +150,30 @@ pub struct SupportedChains { } pub struct ProviderRepository { - pub supported_chains: SupportedChains, + pub rpc_supported_chains: SupportedChains, + rpc_providers: HashMap>, + rpc_weight_resolver: ChainsWeightResolver, - providers: HashMap>, ws_providers: HashMap>, + ws_weight_resolver: ChainsWeightResolver, - weight_resolver: WeightResolver, - ws_weight_resolver: WeightResolver, - - prometheus_client: prometheus_http_query::Client, - prometheus_workspace_header: String, + balance_supported_namespaces: HashSet, + balance_providers: HashMap>, + balance_weight_resolver: NamespacesWeightResolver, pub history_providers: HashMap>, pub portfolio_provider: Arc, pub coinbase_pay_provider: Arc, pub onramp_provider: Arc, - pub balance_providers: HashMap>, + pub conversion_provider: Arc, pub fungible_price_providers: HashMap>, pub bundler_ops_provider: Arc, pub chain_orchestrator_provider: Arc, pub simulation_provider: Arc, + + prometheus_client: prometheus_http_query::Client, + prometheus_workspace_header: String, } impl ProviderRepository { @@ -211,10 +217,7 @@ impl ProviderRepository { // Don't crash the application if the ZERION_API_KEY is not set // TODO: find a better way to handle this - let zerion_api_key = config - .zerion_api_key - .clone() - .unwrap_or("ZERION_KEY_UNDEFINED".into()); + let zerion_api_key = config.zerion_api_key.clone(); // Don't crash the application if the COINBASE_API_KEY_UNDEFINED is not set // TODO: find a better way to handle this @@ -287,21 +290,23 @@ impl ProviderRepository { )); Self { - supported_chains: SupportedChains { + rpc_supported_chains: SupportedChains { http: HashSet::new(), ws: HashSet::new(), }, - providers: HashMap::new(), + rpc_providers: HashMap::new(), + rpc_weight_resolver: HashMap::new(), ws_providers: HashMap::new(), - weight_resolver: HashMap::new(), ws_weight_resolver: HashMap::new(), + balance_supported_namespaces: HashSet::new(), + balance_providers: HashMap::new(), + balance_weight_resolver: HashMap::new(), prometheus_client, prometheus_workspace_header, history_providers, portfolio_provider, coinbase_pay_provider: coinbase_pay_provider.clone(), onramp_provider: coinbase_pay_provider, - balance_providers, conversion_provider: one_inch_provider.clone(), fungible_price_providers, bundler_ops_provider, @@ -311,12 +316,12 @@ impl ProviderRepository { } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_provider_for_chain_id( + pub fn get_rpc_provider_for_chain_id( &self, chain_id: &str, max_providers: usize, ) -> Result>, RpcError> { - let Some(providers) = self.weight_resolver.get(chain_id) else { + let Some(providers) = self.rpc_weight_resolver.get(chain_id) else { return Err(RpcError::UnsupportedChain(chain_id.to_string())); }; @@ -357,7 +362,7 @@ impl ProviderRepository { } }; - self.providers.get(provider).cloned().ok_or_else(|| { + self.rpc_providers.get(provider).cloned().ok_or_else(|| { RpcError::WeightedProvidersIndex(format!( "Provider not found during the weighted index check: {}", provider @@ -376,6 +381,75 @@ impl ProviderRepository { } } + #[tracing::instrument(skip(self), level = "debug")] + pub fn get_balance_provider_for_namespace( + &self, + namespace: &CaipNamespaces, + max_providers: usize, + ) -> Result>, RpcError> { + let Some(providers) = self.balance_weight_resolver.get(namespace) else { + return Err(RpcError::UnsupportedChain(namespace.to_string())); + }; + + if providers.is_empty() { + return Err(RpcError::UnsupportedChain(namespace.to_string())); + } + + let weights: Vec<_> = providers + .iter() + .map(|(_, weight)| weight.value()) + .map(|w| w.min(1)) + .collect(); + let non_zero_weight_providers = weights.iter().filter(|&x| *x > 0).count(); + let keys = providers.keys().cloned().collect::>(); + + match WeightedIndex::new(weights) { + Ok(mut dist) => { + let providers_to_iterate = std::cmp::min(max_providers, non_zero_weight_providers); + let providers_result = (0..providers_to_iterate) + .map(|i| { + let dist_key = dist.sample(&mut OsRng); + let provider = keys.get(dist_key).ok_or_else(|| { + RpcError::WeightedProvidersIndex(format!( + "Failed to get random balanceprovider for namespace: {}", + namespace + )) + })?; + + // Update the weight of the provider to 0 to remove it from the next + // sampling, as updating weights returns an error if + // all weights are zero + if i < providers_to_iterate - 1 { + if let Err(e) = dist.update_weights(&[(dist_key, &0)]) { + return Err(RpcError::WeightedProvidersIndex(format!( + "Failed to update weight in sampling iteration: {}", + e + ))); + } + }; + + self.balance_providers + .get(provider) + .cloned() + .ok_or_else(|| { + RpcError::WeightedProvidersIndex(format!( + "Balance provider not found during the weighted index check: {}", + provider + )) + }) + }) + .collect::, _>>()?; + Ok(providers_result) + } + Err(e) => { + // Respond with temporarily unavailable when all weights are 0 for + // a chain providers + warn!("Failed to create weighted index: {}", e); + Err(RpcError::ChainTemporarilyUnavailable(namespace.to_string())) + } + } + } + #[tracing::instrument(skip(self), level = "debug")] pub fn get_ws_provider_for_chain_id(&self, chain_id: &str) -> Option> { let providers = self.ws_weight_resolver.get(chain_id)?; @@ -420,7 +494,7 @@ impl ProviderRepository { supported_ws_chains .into_iter() .for_each(|(chain_id, (_, weight))| { - self.supported_chains.ws.insert(chain_id.clone()); + self.rpc_supported_chains.ws.insert(chain_id.clone()); self.ws_weight_resolver .entry(chain_id) .or_default() @@ -428,14 +502,14 @@ impl ProviderRepository { }); } - pub fn add_provider + RpcProvider + 'static, C: ProviderConfig>( + pub fn add_rpc_provider + RpcProvider + 'static, C: ProviderConfig>( &mut self, provider_config: C, ) { let provider = T::new(&provider_config); let arc_provider = Arc::new(provider); - self.providers + self.rpc_providers .insert(provider_config.provider_kind(), arc_provider); let provider_kind = provider_config.provider_kind(); @@ -444,8 +518,8 @@ impl ProviderRepository { supported_chains .into_iter() .for_each(|(chain_id, (_, weight))| { - self.supported_chains.http.insert(chain_id.clone()); - self.weight_resolver + self.rpc_supported_chains.http.insert(chain_id.clone()); + self.rpc_weight_resolver .entry(chain_id) .or_default() .insert(provider_kind, weight); @@ -453,6 +527,35 @@ impl ProviderRepository { debug!("Added provider: {}", provider_kind); } + pub fn add_balance_provider< + T: BalanceProviderFactory + BalanceProvider + 'static, + C: BalanceProviderConfig, + >( + &mut self, + provider_config: C, + cache: Option>, + ) { + let provider = T::new(&provider_config, cache); + let arc_provider = Arc::new(provider); + + self.balance_providers + .insert(provider_config.provider_kind(), arc_provider); + + let provider_kind = provider_config.provider_kind(); + let supported_namespaces = provider_config.supported_namespaces(); + + supported_namespaces + .into_iter() + .for_each(|(namespace, weight)| { + self.balance_supported_namespaces.insert(namespace); + self.balance_weight_resolver + .entry(namespace) + .or_default() + .insert(provider_kind, weight); + }); + debug!("Balance provider added: {}", provider_kind); + } + #[tracing::instrument(skip_all, level = "debug")] pub async fn update_weights(&self, metrics: &crate::Metrics) { debug!("Updating weights"); @@ -474,8 +577,8 @@ impl ProviderRepository { { Ok(data) => { let parsed_weights = weights::parse_weights(data); - weights::update_values(&self.weight_resolver, parsed_weights); - weights::record_values(&self.weight_resolver, metrics); + weights::update_values(&self.rpc_weight_resolver, parsed_weights); + weights::record_values(&self.rpc_weight_resolver, metrics); } Err(e) => { warn!("Failed to update weights from prometheus: {}", e); @@ -484,10 +587,13 @@ impl ProviderRepository { } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_provider_by_provider_id(&self, provider_id: &str) -> Option> { + pub fn get_rpc_provider_by_provider_id( + &self, + provider_id: &str, + ) -> Option> { let provider = ProviderKind::from_str(provider_id)?; - self.providers.get(&provider).cloned() + self.rpc_providers.get(&provider).cloned() } } @@ -744,6 +850,10 @@ pub trait BalanceProvider: Send + Sync { ) -> RpcResult; } +pub trait BalanceProviderFactory: BalanceProvider { + fn new(provider_config: &T, cache: Option>) -> Self; +} + #[async_trait] pub trait FungiblePriceProvider: Send + Sync { async fn get_price( diff --git a/src/providers/solscan.rs b/src/providers/solscan.rs index 71b22b1fd..c9836b5b0 100644 --- a/src/providers/solscan.rs +++ b/src/providers/solscan.rs @@ -4,6 +4,7 @@ use { SupportedCurrencies, }, crate::{ + env::SolScanConfig, error::{RpcError, RpcResult}, handlers::{ balance::{BalanceItem, BalanceQuantity, BalanceQueryParams, BalanceResponseBody}, @@ -15,7 +16,7 @@ use { HistoryTransactionURLItem, }, }, - providers::ProviderKind, + providers::{BalanceProviderFactory, ProviderKind}, storage::error::StorageError, utils::crypto::{CaipNamespaces, SOLANA_NATIVE_TOKEN_ADDRESS}, Metrics, @@ -477,6 +478,17 @@ impl BalanceProvider for SolScanProvider { } } +impl BalanceProviderFactory for SolScanProvider { + fn new(provider_config: &SolScanConfig, cache: Option>) -> Self { + Self { + provider_kind: ProviderKind::SolScan, + api_v2_token: provider_config.api_key.clone(), + http_client: reqwest::Client::new(), + redis_caching_pool: cache, + } + } +} + #[async_trait] impl HistoryProvider for SolScanProvider { #[tracing::instrument(skip(self, params), fields(provider = "SolScan"), level = "debug")] diff --git a/src/providers/weights.rs b/src/providers/weights.rs index 731893ea0..4190852dc 100644 --- a/src/providers/weights.rs +++ b/src/providers/weights.rs @@ -1,5 +1,5 @@ use { - super::{ProviderKind, WeightResolver}, + super::{ChainsWeightResolver, ProviderKind}, crate::env::ChainId, prometheus_http_query::response::PromqlResult, std::collections::HashMap, @@ -123,7 +123,7 @@ fn calculate_chain_weight( } #[tracing::instrument(skip_all, level = "debug")] -pub fn update_values(weight_resolver: &WeightResolver, parsed_weights: ParsedWeights) { +pub fn update_values(weight_resolver: &ChainsWeightResolver, parsed_weights: ParsedWeights) { for (provider, (chain_availabilities, provider_availability)) in parsed_weights { for (chain_id, chain_availability) in chain_availabilities { let chain_id = chain_id.0; @@ -150,7 +150,7 @@ pub fn update_values(weight_resolver: &WeightResolver, parsed_weights: ParsedWei } } -pub fn record_values(weight_resolver: &WeightResolver, metrics: &crate::Metrics) { +pub fn record_values(weight_resolver: &ChainsWeightResolver, metrics: &crate::Metrics) { for (chain_id, provider_chain_weight) in weight_resolver { for (provider_kind, weight) in provider_chain_weight { let weight = weight.value(); diff --git a/src/providers/zerion.rs b/src/providers/zerion.rs index eb661fe4d..0145d3ccf 100644 --- a/src/providers/zerion.rs +++ b/src/providers/zerion.rs @@ -1,6 +1,7 @@ use { - super::{BalanceProvider, HistoryProvider, PortfolioProvider}, + super::{BalanceProvider, BalanceProviderFactory, HistoryProvider, PortfolioProvider}, crate::{ + env::ZerionConfig, error::{RpcError, RpcResult}, handlers::{ balance::{BalanceQueryParams, BalanceResponseBody}, @@ -22,6 +23,7 @@ use { Metrics, }, async_trait::async_trait, + deadpool_redis::Pool, serde::{Deserialize, Serialize}, std::{sync::Arc, time::SystemTime}, tap::TapFallible, @@ -528,3 +530,14 @@ impl BalanceProvider for ZerionProvider { Ok(response) } } + +impl BalanceProviderFactory for ZerionProvider { + fn new(provider_config: &ZerionConfig, _cache: Option>) -> Self { + let http_client = reqwest::Client::new(); + Self { + provider_kind: ProviderKind::Zerion, + api_key: provider_config.api_key.clone(), + http_client, + } + } +}