From 43e726f16844a29fd8e5596da38e5f1e5942aa86 Mon Sep 17 00:00:00 2001
From: Stanislas Polu
Date: Thu, 4 Apr 2024 10:26:18 +0200
Subject: [PATCH] Dust Apps: Improve resilience (#4552)

* Always retry model errors at least once
* Bump data source retrieval max top-k to 1024
* Bump parallelism of model execution in Dust apps maps
---
 core/src/data_sources/data_source.rs |  2 +-
 core/src/providers/ai21.rs           | 10 ++++--
 core/src/providers/anthropic.rs      | 20 +++++++++---
 core/src/providers/cohere.rs         | 34 +++++++++++++-------
 core/src/providers/mistral.rs        | 18 ++++++++---
 core/src/providers/openai.rs         | 46 ++++++++++++++++++++--------
 core/src/run.rs                      |  4 +--
 7 files changed, 95 insertions(+), 39 deletions(-)

diff --git a/core/src/data_sources/data_source.rs b/core/src/data_sources/data_source.rs
index e26a5f734dd6..095cee005658 100644
--- a/core/src/data_sources/data_source.rs
+++ b/core/src/data_sources/data_source.rs
@@ -1310,7 +1310,7 @@ impl DataSource {
         Ok(document)
     }
 
-    const MAX_TOP_K_SEARCH: usize = 128;
+    const MAX_TOP_K_SEARCH: usize = 1024;
 
     pub async fn search(
         &self,
diff --git a/core/src/providers/ai21.rs b/core/src/providers/ai21.rs
index a7639e645757..a5a7f180d56f 100644
--- a/core/src/providers/ai21.rs
+++ b/core/src/providers/ai21.rs
@@ -131,9 +131,9 @@ impl AI21LLM {
             Err(ModelError {
                 message: format!("Ai21APIError: {}", error.detail),
                 retryable: Some(ModelErrorRetryOptions {
-                    sleep: Duration::from_millis(2000),
+                    sleep: Duration::from_millis(500),
                     factor: 2,
-                    retries: 8,
+                    retries: 1,
                 }),
             })
         }
@@ -141,7 +141,11 @@ impl AI21LLM {
             let error: Error = serde_json::from_slice(c)?;
             Err(ModelError {
                 message: format!("Ai21APIError: {}", error.detail),
-                retryable: None,
+                retryable: Some(ModelErrorRetryOptions {
+                    sleep: Duration::from_millis(500),
+                    factor: 1,
+                    retries: 1,
+                }),
             })
         }
     }?;
diff --git a/core/src/providers/anthropic.rs b/core/src/providers/anthropic.rs
index 15ca097683b6..d501bafff383 100644
--- a/core/src/providers/anthropic.rs
+++ b/core/src/providers/anthropic.rs
@@ -2,7 +2,7 @@ use crate::providers::embedder::{Embedder, EmbedderVector};
 use crate::providers::llm::{
     ChatMessage, ChatMessageRole, LLMChatGeneration, LLMGeneration, Tokens, LLM,
 };
-use crate::providers::provider::{ModelError, Provider, ProviderID};
+use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID};
 use crate::providers::tiktoken::tiktoken::anthropic_base_singleton;
 use crate::run::Credentials;
 use crate::utils;
@@ -280,7 +280,11 @@ impl AnthropicLLM {
             let error: Error = serde_json::from_slice(c)?;
             Err(ModelError {
                 message: format!("Anthropic API Error: {}", error.to_string()),
-                retryable: None,
+                retryable: Some(ModelErrorRetryOptions {
+                    sleep: Duration::from_millis(500),
+                    factor: 1,
+                    retries: 1,
+                }),
             })
         }
     }?;
@@ -537,7 +541,11 @@ impl AnthropicLLM {
                         "Anthropic API Error: {}",
                         event.error.to_string()
                     ),
-                    retryable: None,
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(500),
+                        factor: 1,
+                        retries: 1,
+                    }),
                 })?;
                 break 'stream;
             }
@@ -750,7 +758,11 @@ impl AnthropicLLM {
             let error: Error = serde_json::from_slice(c)?;
             Err(ModelError {
                 message: format!("Anthropic API Error: {}", error.to_string()),
-                retryable: None,
+                retryable: Some(ModelErrorRetryOptions {
+                    sleep: Duration::from_millis(500),
+                    factor: 1,
+                    retries: 1,
+                }),
             })
         }
     }?;
diff --git a/core/src/providers/cohere.rs b/core/src/providers/cohere.rs
index 49a833f43854..b26e01716bf2 100644
--- a/core/src/providers/cohere.rs
+++ b/core/src/providers/cohere.rs
@@ -58,9 +58,9 @@ async fn api_encode(api_key: &str, text: &str) -> Result<Vec<usize>> {
             Err(ModelError {
                 message: format!("CohereAPIError: {}", error.message),
                 retryable: Some(ModelErrorRetryOptions {
-                    sleep: Duration::from_millis(2000),
+                    sleep: Duration::from_millis(500),
                     factor: 2,
-                    retries: 8,
+                    retries: 3,
                 }),
             })
         }
@@ -68,7 +68,11 @@ async fn api_encode(api_key: &str, text: &str) -> Result<Vec<usize>> {
             let error: Error = serde_json::from_slice(c)?;
             Err(ModelError {
                 message: format!("CohereAPIError: {}", error.message),
-                retryable: None,
+                retryable: Some(ModelErrorRetryOptions {
+                    sleep: Duration::from_millis(500),
+                    factor: 1,
+                    retries: 1,
+                }),
             })
         }
     }?;
@@ -105,9 +109,9 @@ async fn api_decode(api_key: &str, tokens: Vec<usize>) -> Result<String> {
             Err(ModelError {
                 message: format!("CohereAPIError: {}", error.message),
                 retryable: Some(ModelErrorRetryOptions {
-                    sleep: Duration::from_millis(2000),
+                    sleep: Duration::from_millis(500),
                     factor: 2,
-                    retries: 8,
+                    retries: 3,
                 }),
             })
         }
@@ -115,7 +119,11 @@ async fn api_decode(api_key: &str, tokens: Vec<usize>) -> Result<String> {
             let error: Error = serde_json::from_slice(c)?;
             Err(ModelError {
                 message: format!("CohereAPIError: {}", error.message),
-                retryable: None,
+                retryable: Some(ModelErrorRetryOptions {
+                    sleep: Duration::from_millis(500),
+                    factor: 1,
+                    retries: 1,
+                }),
             })
         }
     }?;
@@ -226,9 +234,9 @@ impl CohereLLM {
             Err(ModelError {
                 message: format!("CohereAPIError: {}", error.message),
                 retryable: Some(ModelErrorRetryOptions {
-                    sleep: Duration::from_millis(2000),
+                    sleep: Duration::from_millis(500),
                     factor: 2,
-                    retries: 8,
+                    retries: 3,
                 }),
             })
         }
@@ -450,9 +458,9 @@ impl CohereEmbedder {
             Err(ModelError {
                 message: format!("CohereAPIError: {}", error.message),
                 retryable: Some(ModelErrorRetryOptions {
-                    sleep: Duration::from_millis(2000),
+                    sleep: Duration::from_millis(500),
                     factor: 2,
-                    retries: 8,
+                    retries: 3,
                 }),
             })
         }
@@ -460,7 +468,11 @@ impl CohereEmbedder {
             let error: Error = serde_json::from_slice(c)?;
             Err(ModelError {
                 message: format!("CohereAPIError: {}", error.message),
-                retryable: None,
+                retryable: Some(ModelErrorRetryOptions {
+                    sleep: Duration::from_millis(500),
+                    factor: 1,
+                    retries: 1,
+                }),
             })
         }
     }?;
diff --git a/core/src/providers/mistral.rs b/core/src/providers/mistral.rs
index e5a0b89726b6..1773333a64fb 100644
--- a/core/src/providers/mistral.rs
+++ b/core/src/providers/mistral.rs
@@ -490,14 +490,18 @@ impl MistralAILLM {
                 true => Err(ModelError {
                     message: error.message(),
                     retryable: Some(ModelErrorRetryOptions {
-                        sleep: Duration::from_millis(100),
+                        sleep: Duration::from_millis(500),
                         factor: 2,
                         retries: 3,
                     }),
                 })?,
                 false => Err(ModelError {
                     message: error.message(),
-                    retryable: None,
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(500),
+                        factor: 1,
+                        retries: 1,
+                    }),
                 })?,
             }
             break 'stream;
@@ -710,14 +714,18 @@ impl MistralAILLM {
                 true => Err(ModelError {
                     message: error.message(),
                     retryable: Some(ModelErrorRetryOptions {
-                        sleep: Duration::from_millis(2000),
+                        sleep: Duration::from_millis(500),
                         factor: 2,
-                        retries: 8,
+                        retries: 3,
                     }),
                 }),
                 false => Err(ModelError {
                     message: error.message(),
-                    retryable: None,
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(500),
+                        factor: 1,
+                        retries: 1,
+                    }),
                 }),
             }
         }
diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs
index 971c2b1e0710..4235c4bf738f 100644
--- a/core/src/providers/openai.rs
+++ b/core/src/providers/openai.rs
@@ -273,14 +273,18 @@ pub async fn streamed_completion(
                 true => Err(ModelError {
                     message: error.message(),
                     retryable: Some(ModelErrorRetryOptions {
-                        sleep: Duration::from_millis(100),
+                        sleep: Duration::from_millis(500),
                         factor: 2,
                         retries: 3,
                     }),
                 })?,
                 false => Err(ModelError {
                     message: error.message(),
-                    retryable: None,
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(500),
+                        factor: 1,
+                        retries: 1,
+                    }),
                 })?,
             }
             break 'stream;
@@ -500,14 +504,18 @@ pub async fn completion(
                 true => Err(ModelError {
                     message: error.message(),
                     retryable: Some(ModelErrorRetryOptions {
-                        sleep: Duration::from_millis(2000),
+                        sleep: Duration::from_millis(500),
                         factor: 2,
-                        retries: 8,
+                        retries: 3,
                     }),
                 }),
                 false => Err(ModelError {
                     message: error.message(),
-                    retryable: None,
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(500),
+                        factor: 1,
+                        retries: 1,
+                    }),
                 }),
             }
         }
@@ -651,14 +659,18 @@ pub async fn streamed_chat_completion(
                 true => Err(ModelError {
                     message: error.message(),
                     retryable: Some(ModelErrorRetryOptions {
-                        sleep: Duration::from_millis(100),
+                        sleep: Duration::from_millis(500),
                         factor: 2,
                         retries: 3,
                     }),
                 })?,
                 false => Err(ModelError {
                     message: error.message(),
-                    retryable: None,
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(500),
+                        factor: 1,
+                        retries: 1,
+                    }),
                 })?,
             }
             break 'stream;
@@ -982,14 +994,18 @@ pub async fn chat_completion(
                 true => Err(ModelError {
                     message: error.message(),
                     retryable: Some(ModelErrorRetryOptions {
-                        sleep: Duration::from_millis(2000),
+                        sleep: Duration::from_millis(500),
                         factor: 2,
-                        retries: 8,
+                        retries: 3,
                     }),
                 }),
                 false => Err(ModelError {
                     message: error.message(),
-                    retryable: None,
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(500),
+                        factor: 1,
+                        retries: 1,
+                    }),
                 }),
             }
         }
@@ -1078,14 +1094,18 @@ pub async fn embed(
                 true => Err(ModelError {
                     message: error.message(),
                     retryable: Some(ModelErrorRetryOptions {
-                        sleep: Duration::from_millis(2000),
+                        sleep: Duration::from_millis(500),
                         factor: 2,
-                        retries: 8,
+                        retries: 3,
                     }),
                 }),
                 false => Err(ModelError {
                     message: error.message(),
-                    retryable: None,
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(500),
+                        factor: 1,
+                        retries: 1,
+                    }),
                 }),
             }
         }
diff --git a/core/src/run.rs b/core/src/run.rs
index 12bdb14a1d21..f44695ea061c 100644
--- a/core/src/run.rs
+++ b/core/src/run.rs
@@ -47,8 +47,8 @@ impl RunConfig {
             BlockType::Data => 64,
             BlockType::DataSource => 8,
             BlockType::Code => 64,
-            BlockType::LLM => 8,
-            BlockType::Chat => 8,
+            BlockType::LLM => 32,
+            BlockType::Chat => 32,
             BlockType::Map => 64,
             BlockType::Reduce => 64,
             BlockType::Search => 8,
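
For context on the retry semantics this patch changes, below is a minimal, self-contained sketch of how a caller might consume ModelErrorRetryOptions. It is not the actual retry loop in core (the real types live in core/src/providers/provider.rs; the field types, the with_retries helper, and the exponential-backoff reading delay = sleep * factor^attempt are assumptions of the sketch). What it illustrates is the behavior change: errors that previously carried retryable: None now carry Some(ModelErrorRetryOptions { sleep: 500ms, factor: 1, retries: 1 }), so every model error is retried at least once, while rate-limit-style errors keep a true exponential backoff (factor: 2).

use std::time::Duration;

// Simplified stand-ins for the real types in core/src/providers/provider.rs.
// The concrete field types here are assumptions made for this sketch.
struct ModelErrorRetryOptions {
    sleep: Duration,
    factor: u32,
    retries: usize,
}

struct ModelError {
    message: String,
    retryable: Option<ModelErrorRetryOptions>,
}

// Hypothetical retry driver: on a retryable error, sleep with exponential
// backoff (sleep * factor^attempt) and retry, up to `retries` extra attempts.
fn with_retries<T>(mut call: impl FnMut() -> Result<T, ModelError>) -> Result<T, ModelError> {
    let mut attempt: u32 = 0;
    loop {
        match call() {
            Ok(v) => return Ok(v),
            Err(e) => match &e.retryable {
                Some(opts) if (attempt as usize) < opts.retries => {
                    std::thread::sleep(opts.sleep * opts.factor.pow(attempt));
                    attempt += 1;
                }
                _ => return Err(e),
            },
        }
    }
}

fn main() {
    // A call that fails once with the new "retry once" options, then succeeds:
    // with retries: 1 and factor: 1, the single 500ms retry recovers it.
    let mut failures = 1;
    let result = with_retries(|| {
        if failures > 0 {
            failures -= 1;
            Err(ModelError {
                message: "CohereAPIError: transient".to_string(),
                retryable: Some(ModelErrorRetryOptions {
                    sleep: Duration::from_millis(500),
                    factor: 1,
                    retries: 1,
                }),
            })
        } else {
            Ok("ok")
        }
    });
    println!("{:?}", result.map_err(|e| e.message)); // prints: Ok("ok")
}

Under this reading, the choice of factor: 1, retries: 1 for previously non-retryable errors is the cheapest possible resilience bump: one quick, non-escalating retry per model error, which matches the patch's first bullet without risking long retry storms on genuinely failing calls.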