diff --git a/connectors/src/connectors/slack/bot.ts b/connectors/src/connectors/slack/bot.ts
index 70370487fca2..2cd5f8a04f88 100644
--- a/connectors/src/connectors/slack/bot.ts
+++ b/connectors/src/connectors/slack/bot.ts
@@ -1,7 +1,9 @@
 import {
+  AgentActionType,
   AgentGenerationSuccessEvent,
   AgentMessageType,
   DustAPI,
+  RetrievalDocumentType,
 } from "@connectors/lib/dust_api";
 import {
   Connector,
@@ -161,7 +163,7 @@ async function botAnswerMessage(
   });
   const mainMessage = await slackClient.chat.postMessage({
     channel: slackChannel,
-    text: "_I am thinking..._",
+    text: "_Thinking..._",
     thread_ts: slackMessageTs,
     mrkdwn: true,
   });
@@ -238,6 +240,7 @@ async function botAnswerMessage(
   }

   let fullAnswer = "";
+  let action: AgentActionType | null = null;
   let lastSentDate = new Date();
   for await (const event of streamRes.value.eventStream) {
     switch (event.type) {
@@ -258,24 +261,31 @@ async function botAnswerMessage(
       }
       case "generation_tokens": {
         fullAnswer += event.text;
-        if (lastSentDate.getTime() + 1000 > new Date().getTime()) {
+        if (lastSentDate.getTime() + 1500 > new Date().getTime()) {
           continue;
         }
         lastSentDate = new Date();
+
+        let finalAnswer = _processCiteMention(fullAnswer, action);
+        finalAnswer += `...\n\n <${DUST_API}/w/${connector.workspaceId}/assistant/${conversation.sId}|Continue this conversation on Dust>`;
+
         await slackClient.chat.update({
           channel: slackChannel,
-          text: fullAnswer,
+          text: finalAnswer,
           ts: mainMessage.ts as string,
           thread_ts: slackMessageTs,
         });
         break;
       }
+      case "agent_action_success": {
+        action = event.action;
+        break;
+      }
       case "agent_generation_success": {
-        const finalAnswer = `${_removeCiteMention(
-          event.text
-        )}\n\n <${DUST_API}/w/${connector.workspaceId}/assistant/${
-          conversation.sId
-        }|Continue this conversation on Dust>`;
+        fullAnswer = event.text;
+
+        let finalAnswer = _processCiteMention(fullAnswer, action);
+        finalAnswer += `\n\n <${DUST_API}/w/${connector.workspaceId}/assistant/${conversation.sId}|Continue this conversation on Dust>`;

         await slackClient.chat.update({
           channel: slackChannel,
@@ -293,11 +303,45 @@ async function botAnswerMessage(
   return new Err(new Error("Failed to get the final answer from Dust"));
 }

-/*
- * Temp > until I have a PR to properly handle mentions
- */
-function _removeCiteMention(message: string) {
-  const regex = /:cite\[[a-zA-Z0-9,]+\]/g;
+function _processCiteMention(
+  content: string,
+  action: AgentActionType | null
+): string {
+  const references: { [key: string]: RetrievalDocumentType } = {};
+
+  if (action && action.type === "retrieval_action" && action.documents) {
+    action.documents.forEach((d) => {
+      references[d.reference] = d;
+    });
+  }
+
+  if (Object.keys(references).length > 0) {
+    let counter = 0;
+    const refCounter: { [key: string]: number } = {};
+    return content.replace(/:cite\[[a-zA-Z0-9, ]+\]/g, (match) => {
+      const keys = match.slice(6, -1).split(","); // slice off ":cite[" and "]" then split by comma
+      return keys
+        .map((key) => {
+          const k = key.trim();
+          const ref = references[k];
+          if (ref && ref.sourceUrl) {
+            if (!refCounter[k]) {
+              counter++;
+              refCounter[k] = counter;
+            }
+            return `[<${ref.sourceUrl}|${refCounter[k]}>]`;
+          }
+          return "";
+        })
+        .join("");
+    });
+  }
+
+  return _removeCiteMention(content);
+}
+
+function _removeCiteMention(message: string): string {
+  const regex = /:cite\[[a-zA-Z0-9, ]+\]/g;
   return message.replace(regex, "");
 }
diff --git a/connectors/src/lib/dust_api.ts b/connectors/src/lib/dust_api.ts
index 9e28bb83ec2d..105367a1528a 100644
--- a/connectors/src/lib/dust_api.ts
+++ b/connectors/src/lib/dust_api.ts
@@ -280,7 +280,6 @@ export type AgentMessageType = {
   visibility: MessageVisibility;
   version: number;
   parentMessageId: string | null;
-  // configuration: AgentConfigurationType;
   status: AgentMessageStatus;
   action: AgentActionType | null;
diff --git a/core/bin/dust_api.rs b/core/bin/dust_api.rs
index 5ebff27ea641..c1c5c06819ce 100644
--- a/core/bin/dust_api.rs
+++ b/core/bin/dust_api.rs
@@ -505,7 +505,7 @@ async fn run_helper(
         None => Err(error_response(
             StatusCode::BAD_REQUEST,
             "missing_specification_error",
-            "No specification provided, either `specification` or
+            "No specification provided, either `specification` or `specification_hash` must be provided",
             None,
         ))?,
@@ -1610,8 +1610,8 @@ struct TokenizePayload {

 async fn tokenize(
     extract::Json(payload): extract::Json<TokenizePayload>,
 ) -> (StatusCode, Json<APIResponse>) {
-    let embedder = provider(payload.provider_id).embedder(payload.model_id);
-    match embedder.tokenize(payload.text).await {
+    let llm = provider(payload.provider_id).llm(payload.model_id);
+    match llm.tokenize(&payload.text).await {
         Err(e) => error_response(
             StatusCode::INTERNAL_SERVER_ERROR,
             "internal_server_error",
diff --git a/core/src/lib.rs b/core/src/lib.rs
index b91e0705dad9..0a49058d4f72 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -30,6 +30,7 @@ pub mod providers {
         pub mod tiktoken;
     }
     pub mod anthropic;
+    pub mod textsynth;
 }
 pub mod http {
     pub mod request;
diff --git a/core/src/providers/ai21.rs b/core/src/providers/ai21.rs
index 0e157394705f..d52a995de754 100644
--- a/core/src/providers/ai21.rs
+++ b/core/src/providers/ai21.rs
@@ -190,6 +190,10 @@ impl LLM for AI21LLM {
         Err(anyhow!("Encode/Decode not implemented for provider `ai21`"))
     }

+    async fn tokenize(&self, _text: &str) -> Result<Vec<(usize, String)>> {
+        Err(anyhow!("Tokenize not implemented for provider `ai21`"))
+    }
+
     async fn generate(
         &self,
         prompt: &str,
@@ -385,10 +389,6 @@ impl Embedder for AI21Embedder {
         Err(anyhow!("Encode/Decode not implemented for provider `ai21`"))
     }

-    async fn tokenize(&self, _text: String) -> Result<Vec<(usize, String)>> {
-        Err(anyhow!("Tokenize not implemented for provider `ai21`"))
-    }
-
     async fn embed(&self, _text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
         Err(anyhow!("Embeddings not available for provider `ai21`"))
     }
diff --git a/core/src/providers/anthropic.rs b/core/src/providers/anthropic.rs
index 723733223b72..04af3ec2ab13 100644
--- a/core/src/providers/anthropic.rs
+++ b/core/src/providers/anthropic.rs
@@ -572,6 +572,10 @@ impl LLM for AnthropicLLM {
         decode_async(anthropic_base_singleton(), tokens).await
     }

+    async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
+        tokenize_async(anthropic_base_singleton(), text).await
+    }
+
     async fn chat(
         &self,
         messages: &Vec<ChatMessage>,
@@ -665,10 +669,6 @@ impl Embedder for AnthropicEmbedder {
         decode_async(anthropic_base_singleton(), tokens).await
     }

-    async fn tokenize(&self, text: String) -> Result<Vec<(usize, String)>> {
-        tokenize_async(anthropic_base_singleton(), text).await
-    }
-
     async fn embed(&self, _text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
         Err(anyhow!("Embeddings not available for provider `anthropic`"))
     }
diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs
index 396b30ff80d2..ea17601f79cb 100644
--- a/core/src/providers/azure_openai.rs
+++ b/core/src/providers/azure_openai.rs
@@ -1,3 +1,4 @@
+use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
 use crate::providers::embedder::{Embedder, EmbedderVector};
 use crate::providers::llm::Tokens;
 use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLM};
@@ -265,13 +266,15 @@ impl LLM for AzureOpenAILLM {
     }

     async fn encode(&self, text: &str) -> Result<Vec<usize>> {
-        let tokens = { self.tokenizer().lock().encode_with_special_tokens(text) };
-        Ok(tokens)
+        encode_async(self.tokenizer(), text).await
     }

     async fn decode(&self, tokens: Vec<usize>) -> Result<String> {
-        let str = { self.tokenizer().lock().decode(tokens)? };
-        Ok(str)
+        decode_async(self.tokenizer(), tokens).await
+    }
+
+    async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
+        tokenize_async(self.tokenizer(), text).await
     }

     async fn generate(
@@ -600,10 +603,6 @@ impl Embedder for AzureOpenAIEmbedder {
         Ok(str)
     }

-    async fn tokenize(&self, _text: String) -> Result<Vec<(usize, String)>> {
-        Err(anyhow!("Tokenize not implemented for provider `anthropic`"))
-    }
-
     async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
         let e = embed(
             self.uri()?,
diff --git a/core/src/providers/cohere.rs b/core/src/providers/cohere.rs
index 19a73ca0f992..2e8b5d92bec9 100644
--- a/core/src/providers/cohere.rs
+++ b/core/src/providers/cohere.rs
@@ -297,6 +297,16 @@ impl LLM for CohereLLM {
         api_decode(self.api_key.as_ref().unwrap(), tokens).await
     }

+    // We return empty strings in place of token strings to partially support the endpoint.
+    async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
+        assert!(self.api_key.is_some());
+        let tokens = api_encode(self.api_key.as_ref().unwrap(), text).await?;
+        Ok(tokens
+            .iter()
+            .map(|t| (*t, "".to_string()))
+            .collect::<Vec<(usize, String)>>())
+    }
+
     async fn generate(
         &self,
         prompt: &str,
@@ -534,10 +544,6 @@ impl Embedder for CohereEmbedder {
         api_decode(self.api_key.as_ref().unwrap(), tokens).await
     }

-    async fn tokenize(&self, _text: String) -> Result<Vec<(usize, String)>> {
-        Err(anyhow!("Tokenize not implemented for provider `Cohere`"))
-    }
-
     async fn embed(&self, text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
         assert!(self.api_key.is_some());
diff --git a/core/src/providers/embedder.rs b/core/src/providers/embedder.rs
index 86b0f69a4fbc..1257fb48f894 100644
--- a/core/src/providers/embedder.rs
+++ b/core/src/providers/embedder.rs
@@ -26,8 +26,6 @@ pub trait Embedder {
     async fn encode(&self, text: &str) -> Result<Vec<usize>>;
     async fn decode(&self, tokens: Vec<usize>) -> Result<String>;

-    async fn tokenize(&self, text: String) -> Result<Vec<(usize, String)>>;
-
     async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>>;
 }
diff --git a/core/src/providers/llm.rs b/core/src/providers/llm.rs
index 1a372b78d231..46b7e1d865f8 100644
--- a/core/src/providers/llm.rs
+++ b/core/src/providers/llm.rs
@@ -103,6 +103,7 @@ pub trait LLM {

     async fn encode(&self, text: &str) -> Result<Vec<usize>>;
     async fn decode(&self, tokens: Vec<usize>) -> Result<String>;
+    async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>>;

     async fn generate(
         &self,
diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs
index 19542723e52f..7f6eaca59a66 100644
--- a/core/src/providers/openai.rs
+++ b/core/src/providers/openai.rs
@@ -1160,6 +1160,10 @@ impl LLM for OpenAILLM {
         decode_async(self.tokenizer(), tokens).await
     }

+    async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
+        tokenize_async(self.tokenizer(), text).await
+    }
+
     async fn generate(
         &self,
         prompt: &str,
@@ -1575,10 +1579,6 @@ impl Embedder for OpenAIEmbedder {
         decode_async(self.tokenizer(), tokens).await
     }

-    async fn tokenize(&self, text: String) -> Result<Vec<(usize, String)>> {
-        tokenize_async(self.tokenizer(), text).await
-    }
-
     async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
         let e = embed(
             self.uri()?,
diff --git a/core/src/providers/provider.rs b/core/src/providers/provider.rs
index a6f7ca78a321..0ad36b27b51c 100644
--- a/core/src/providers/provider.rs
+++ b/core/src/providers/provider.rs
@@ -13,6 +13,8 @@ use serde::{Deserialize, Serialize};
 use std::str::FromStr;
 use std::time::Duration;

+use super::textsynth::TextSynthProvider;
+
 #[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)]
 #[serde(rename_all = "lowercase")]
 pub enum ProviderID {
@@ -22,6 +24,7 @@ pub enum ProviderID {
     #[serde(rename = "azure_openai")]
     AzureOpenAI,
     Anthropic,
+    TextSynth,
 }

 impl ToString for ProviderID {
@@ -32,6 +35,7 @@ impl ToString for ProviderID {
             ProviderID::AI21 => String::from("ai21"),
             ProviderID::AzureOpenAI => String::from("azure_openai"),
             ProviderID::Anthropic => String::from("anthropic"),
+            ProviderID::TextSynth => String::from("textsynth"),
         }
     }
 }
@@ -45,6 +49,7 @@ impl FromStr for ProviderID {
             "ai21" => Ok(ProviderID::AI21),
             "azure_openai" => Ok(ProviderID::AzureOpenAI),
             "anthropic" => Ok(ProviderID::Anthropic),
+            "textsynth" => Ok(ProviderID::TextSynth),
             _ => Err(ParseError::with_message(
                 "Unknown provider ID (possible values: openai, cohere, ai21, azure_openai)",
             ))?,
@@ -139,5 +144,6 @@ pub fn provider(t: ProviderID) -> Box<dyn Provider + Sync + Send> {
         ProviderID::AI21 => Box::new(AI21Provider::new()),
         ProviderID::AzureOpenAI => Box::new(AzureOpenAIProvider::new()),
         ProviderID::Anthropic => Box::new(AnthropicProvider::new()),
+        ProviderID::TextSynth => Box::new(TextSynthProvider::new()),
     }
 }
diff --git a/core/src/providers/textsynth.rs b/core/src/providers/textsynth.rs
new file mode 100644
index 000000000000..72865c925379
--- /dev/null
+++ b/core/src/providers/textsynth.rs
@@ -0,0 +1,482 @@
+use crate::providers::embedder::Embedder;
+use crate::providers::llm::Tokens;
+use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLM};
+use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID};
+use crate::run::Credentials;
+use crate::utils;
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use hyper::{body::Buf, Body, Client, Method, Request, Uri};
+use hyper_tls::HttpsConnector;
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+use std::io::prelude::*;
+use std::time::Duration;
+use tokio::sync::mpsc::UnboundedSender;
+
+use super::embedder::EmbedderVector;
+use super::llm::ChatFunction;
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct Error {
+    pub error: String,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct TokenizeResponse {
+    pub tokens: Vec<usize>,
+}
+
+async fn api_tokenize(api_key: &str, engine: &str, text: &str) -> Result<Vec<usize>> {
+    let https = HttpsConnector::new();
+    let cli = Client::builder().build::<_, hyper::Body>(https);
+
+    let req = Request::builder()
+        .method(Method::POST)
+        .uri(format!("https://api.textsynth.com/v1/engines/{}/tokenize", engine).parse::<Uri>()?)
+ .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(Body::from( + json!({ + "text": text, + }) + .to_string(), + ))?; + + let res = cli.request(req).await?; + let status = res.status(); + let body = hyper::body::aggregate(res).await?; + let mut b: Vec = vec![]; + body.reader().read_to_end(&mut b)?; + let c: &[u8] = &b; + + let r = match status { + hyper::StatusCode::OK => { + let r: TokenizeResponse = serde_json::from_slice(c)?; + Ok(r) + } + hyper::StatusCode::TOO_MANY_REQUESTS => { + let error: Error = serde_json::from_slice(c).unwrap_or(Error { + error: "Too many requests".to_string(), + }); + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(2000), + factor: 2, + retries: 8, + }), + }) + } + hyper::StatusCode::BAD_REQUEST => { + let error: Error = serde_json::from_slice(c).unwrap_or(Error { + error: "Unknown error".to_string(), + }); + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: None, + }) + } + _ => { + let error: Error = serde_json::from_slice(c)?; + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: None, + }) + } + }?; + Ok(r.tokens) +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Completion { + pub text: String, + pub reached_end: bool, + pub input_tokens: usize, + pub output_tokens: usize, +} + +pub struct TextSynthLLM { + id: String, + api_key: Option, +} + +impl TextSynthLLM { + pub fn new(id: String) -> Self { + TextSynthLLM { id, api_key: None } + } + + fn uri(&self) -> Result { + Ok(format!( + "https://api.textsynth.com/v1/engines/{}/completions", + self.id + ) + .parse::()?) + } + + // fn chat_uri(&self) -> Result { + // Ok(format!("https://api.textsynth.com/v1/engines/{}/chat", self.id).parse::()?) + // } + + async fn completion( + &self, + prompt: &str, + max_tokens: Option, + temperature: f32, + stop: &Vec, + top_k: Option, + top_p: Option, + frequency_penalty: Option, + presence_penalty: Option, + repetition_penalty: Option, + typical_p: Option, + ) -> Result { + assert!(self.api_key.is_some()); + + let https = HttpsConnector::new(); + let cli = Client::builder().build::<_, hyper::Body>(https); + + let mut body = json!({ + "prompt": prompt, + "temperature": temperature, + }); + if max_tokens.is_some() { + body["max_tokens"] = json!(max_tokens.unwrap()); + } + if stop.len() > 0 { + body["stop"] = json!(stop); + } + if top_k.is_some() { + body["top_k"] = json!(top_k.unwrap()); + } + if top_p.is_some() { + body["top_p"] = json!(top_k.unwrap()); + } + if frequency_penalty.is_some() { + body["frequency_penalty"] = json!(frequency_penalty.unwrap()); + } + if presence_penalty.is_some() { + body["presence_penalty"] = json!(presence_penalty.unwrap()); + } + if repetition_penalty.is_some() { + body["repetition_penalty"] = json!(repetition_penalty.unwrap()); + } + if typical_p.is_some() { + body["typical_p"] = json!(typical_p.unwrap()); + } + + let req = Request::builder() + .method(Method::POST) + .uri(self.uri()?) 
+ .header("Content-Type", "application/json") + .header( + "Authorization", + format!("Bearer {}", self.api_key.clone().unwrap()), + ) + .body(Body::from(body.to_string()))?; + + let res = cli.request(req).await?; + let status = res.status(); + let body = hyper::body::aggregate(res).await?; + let mut b: Vec = vec![]; + body.reader().read_to_end(&mut b)?; + let c: &[u8] = &b; + + let response = match status { + hyper::StatusCode::OK => { + let completion: Completion = serde_json::from_slice(c)?; + Ok(completion) + } + hyper::StatusCode::TOO_MANY_REQUESTS => { + let error: Error = serde_json::from_slice(c).unwrap_or(Error { + error: "Too many requests".to_string(), + }); + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(2000), + factor: 2, + retries: 8, + }), + }) + } + hyper::StatusCode::BAD_REQUEST => { + let error: Error = serde_json::from_slice(c).unwrap_or(Error { + error: "Unknown error".to_string(), + }); + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: None, + }) + } + _ => { + let error: Error = serde_json::from_slice(c)?; + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: None, + }) + } + }?; + Ok(response) + } +} + +#[async_trait] +impl LLM for TextSynthLLM { + fn id(&self) -> String { + self.id.clone() + } + + async fn initialize(&mut self, credentials: Credentials) -> Result<()> { + match credentials.get("TEXTSYNTH_API_KEY") { + Some(api_key) => { + self.api_key = Some(api_key.clone()); + } + None => match tokio::task::spawn_blocking(|| std::env::var("TEXTSYNTH_API_KEY")).await? + { + Ok(key) => { + self.api_key = Some(key); + } + Err(_) => Err(anyhow!( + "Credentials or environment variable `TEXTSYNTH_API_KEY` is not set." + ))?, + }, + } + Ok(()) + } + + fn context_size(&self) -> usize { + match self.id.as_str() { + "mistral_7B" => 4096, + "mistral_7B_instruct" => 4096, + "falcon_7B" => 2048, + "falcon_40B" => 2048, + "falcon_40B-chat" => 2048, + "llama2_7B" => 4096, + _ => 2048, + } + } + + async fn encode(&self, text: &str) -> Result> { + assert!(self.api_key.is_some()); + + api_tokenize(self.api_key.as_ref().unwrap(), self.id.as_str(), text).await + } + + async fn decode(&self, _tokens: Vec) -> Result { + Err(anyhow!( + "Encode/Decode not implemented for provider `textsynth`" + )) + } + + // We return empty strings in place of tokens strings to partially support the endpoint. + async fn tokenize(&self, text: &str) -> Result> { + api_tokenize(self.api_key.as_ref().unwrap(), self.id.as_str(), text) + .await? 
+            .into_iter()
+            .map(|t| Ok((t, String::from(""))))
+            .collect()
+    }
+
+    async fn generate(
+        &self,
+        prompt: &str,
+        mut max_tokens: Option<i32>,
+        temperature: f32,
+        _n: usize,
+        stop: &Vec<String>,
+        frequency_penalty: Option<f32>,
+        presence_penalty: Option<f32>,
+        top_p: Option<f32>,
+        _top_logprobs: Option<i32>,
+        _extras: Option<Value>,
+        _event_sender: Option<UnboundedSender<Value>>,
+    ) -> Result<LLMGeneration> {
+        assert!(self.api_key.is_some());
+
+        if let Some(m) = max_tokens {
+            if m == -1 {
+                let tokens = self.encode(prompt).await?;
+                max_tokens = Some((self.context_size() - tokens.len()) as i32);
+            }
+        }
+
+        let c = self
+            .completion(
+                prompt,
+                max_tokens,
+                temperature,
+                stop,
+                None,
+                top_p,
+                frequency_penalty,
+                presence_penalty,
+                None,
+                None,
+            )
+            .await?;
+
+        Ok(LLMGeneration {
+            created: utils::now(),
+            provider: ProviderID::TextSynth.to_string(),
+            model: self.id.clone(),
+            completions: vec![Tokens {
+                text: c.text.clone(),
+                tokens: Some(vec![]),
+                logprobs: Some(vec![]),
+                top_logprobs: Some(vec![]),
+            }],
+            prompt: Tokens {
+                text: prompt.to_string(),
+                tokens: None,
+                logprobs: None,
+                top_logprobs: None,
+            },
+        })
+    }
+
+    async fn chat(
+        &self,
+        _messages: &Vec<ChatMessage>,
+        _functions: &Vec<ChatFunction>,
+        _function_call: Option<String>,
+        _temperature: f32,
+        _top_p: Option<f32>,
+        _n: usize,
+        _stop: &Vec<String>,
+        _max_tokens: Option<i32>,
+        _presence_penalty: Option<f32>,
+        _frequency_penalty: Option<f32>,
+        _extras: Option<Value>,
+        _event_sender: Option<UnboundedSender<Value>>,
+    ) -> Result<LLMChatGeneration> {
+        Err(anyhow!(
+            "Chat capabilities are not implemented for provider `textsynth`"
+        ))
+    }
+}
+
+pub struct TextSynthEmbedder {
+    id: String,
+}
+
+impl TextSynthEmbedder {
+    pub fn new(id: String) -> Self {
+        TextSynthEmbedder { id }
+    }
+}
+
+#[async_trait]
+impl Embedder for TextSynthEmbedder {
+    fn id(&self) -> String {
+        self.id.clone()
+    }
+
+    async fn initialize(&mut self, _credentials: Credentials) -> Result<()> {
+        Err(anyhow!("Embedders not available for provider `textsynth`"))
+    }
+
+    fn context_size(&self) -> usize {
+        2048
+    }
+    fn embedding_size(&self) -> usize {
+        2048
+    }
+
+    async fn encode(&self, _text: &str) -> Result<Vec<usize>> {
+        Err(anyhow!(
+            "Encode/Decode not implemented for provider `textsynth`"
+        ))
+    }
+
+    async fn decode(&self, _tokens: Vec<usize>) -> Result<String> {
+        Err(anyhow!(
+            "Encode/Decode not implemented for provider `textsynth`"
+        ))
+    }
+
+    async fn embed(&self, _text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
+        Err(anyhow!("Embeddings not available for provider `textsynth`"))
+    }
+}
+
+pub struct TextSynthProvider {}
+
+impl TextSynthProvider {
+    pub fn new() -> Self {
+        TextSynthProvider {}
+    }
+}
+
+#[async_trait]
+impl Provider for TextSynthProvider {
+    fn id(&self) -> ProviderID {
+        ProviderID::TextSynth
+    }
+
+    fn setup(&self) -> Result<()> {
+        utils::info("Setting up TextSynth:");
+        utils::info("");
+        utils::info(
+            "To use TextSynth's models, you must set the environment variable `TEXTSYNTH_API_KEY`.",
+        );
+        utils::info("Your API key can be found at `https://textsynth.com/settings.html`.");
+        utils::info("");
+        utils::info("Once ready you can check your setup with `dust provider test textsynth`");
+
+        Ok(())
+    }
+
+    async fn test(&self) -> Result<()> {
+        if !utils::confirm(
+            "You are about to make a request for 1 token to `mistral_7B` on the TextSynth API.",
+        )?
{ + Err(anyhow!("User aborted TextSynth test."))?; + } + + let mut llm = self.llm(String::from("mistral_7B")); + llm.initialize(Credentials::new()).await?; + + let _ = llm + .generate( + "Hello 😊", + Some(1), + 0.7, + 1, + &vec![], + None, + None, + None, + None, + None, + None, + ) + .await?; + + // let t = llm.encode("Hello 😊").await?; + // let d = llm.decode(t).await?; + // assert!(d == "Hello 😊"); + + // let mut embedder = self.embedder(String::from("large")); + // embedder.initialize(Credentials::new()).await?; + + // let _v = embedder.embed("Hello 😊", None).await?; + // println!("EMBEDDING SIZE: {}", v.vector.len()); + + utils::done("Test successfully completed! TextSynth is ready to use."); + + Ok(()) + } + + fn llm(&self, id: String) -> Box { + Box::new(TextSynthLLM::new(id)) + } + + fn embedder(&self, id: String) -> Box { + Box::new(TextSynthEmbedder::new(id)) + } +} diff --git a/core/src/providers/tiktoken/tiktoken.rs b/core/src/providers/tiktoken/tiktoken.rs index 18ee534850d9..d894461270c2 100644 --- a/core/src/providers/tiktoken/tiktoken.rs +++ b/core/src/providers/tiktoken/tiktoken.rs @@ -145,10 +145,8 @@ pub async fn encode_async(bpe: Arc>, text: &str) -> Result>, - text: String, -) -> Result> { +pub async fn tokenize_async(bpe: Arc>, text: &str) -> Result> { + let text = text.to_string(); let r = task::spawn_blocking(move || bpe.lock().tokenize(&text)).await?; Ok(r) } @@ -643,7 +641,7 @@ impl CoreBPE { self._encode_native(text, &allowed_special).0 } - pub fn tokenize(&self, text: &String) -> Vec<(usize, String)> { + pub fn tokenize(&self, text: &str) -> Vec<(usize, String)> { let allowed_special = self .special_tokens_encoder .keys() @@ -842,7 +840,7 @@ mod tests { #[tokio::test] async fn tokenize_test() { - async fn run_tokenize_test(soupinou: String, expected_soupinou: Vec<(usize, String)>) { + async fn run_tokenize_test(soupinou: &str, expected_soupinou: Vec<(usize, String)>) { let bpe = p50k_base_singleton(); let res = tokenize_async(bpe, soupinou).await; assert_eq!(res.unwrap(), expected_soupinou); @@ -857,7 +855,7 @@ mod tests { (259, "in".to_string()), (280, "ou".to_string()), ]; - run_tokenize_test(regular, expected_regular).await; + run_tokenize_test(®ular, expected_regular).await; let unicode = "Soupinou 🤗".to_string(); let expected_unicode: Vec<(usize, String)> = vec![ @@ -870,7 +868,7 @@ mod tests { (245, "�".to_string()), ]; - run_tokenize_test(unicode, expected_unicode).await; + run_tokenize_test(&unicode, expected_unicode).await; let japanese = "ほこり".to_string(); let expected_japanese: Vec<(usize, String)> = vec![ @@ -880,6 +878,6 @@ mod tests { (28255, "り".to_string()), ]; - run_tokenize_test(japanese, expected_japanese).await; + run_tokenize_test(&japanese, expected_japanese).await; } } diff --git a/front/components/providers/TextSynthSetup.tsx b/front/components/providers/TextSynthSetup.tsx new file mode 100644 index 000000000000..09ad937308b0 --- /dev/null +++ b/front/components/providers/TextSynthSetup.tsx @@ -0,0 +1,208 @@ +import { Button } from "@dust-tt/sparkle"; +import { Dialog, Transition } from "@headlessui/react"; +import { Fragment, useEffect, useState } from "react"; +import { useSWRConfig } from "swr"; + +import { checkProvider } from "@app/lib/providers"; +import { WorkspaceType } from "@app/types/user"; + +export default function TextSynthSetup({ + owner, + open, + setOpen, + config, + enabled, +}: { + owner: WorkspaceType; + open: boolean; + setOpen: (open: boolean) => void; + config: { [key: string]: string }; + enabled: boolean; +}) { + 
+  const { mutate } = useSWRConfig();
+
+  const [apiKey, setApiKey] = useState(config ? config.api_key : "");
+  const [testSuccessful, setTestSuccessful] = useState(false);
+  const [testRunning, setTestRunning] = useState(false);
+  const [testError, setTestError] = useState("");
+  const [enableRunning, setEnableRunning] = useState(false);
+
+  useEffect(() => {
+    if (config && config.api_key.length > 0 && apiKey.length == 0) {
+      setApiKey(config.api_key);
+    }
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [config]);
+
+  const runTest = async () => {
+    setTestRunning(true);
+    setTestError("");
+    const check = await checkProvider(owner, "textsynth", {
+      api_key: apiKey,
+    });
+
+    if (!check.ok) {
+      setTestError(check.error || "Unknown error");
+      setTestSuccessful(false);
+      setTestRunning(false);
+    } else {
+      setTestError("");
+      setTestSuccessful(true);
+      setTestRunning(false);
+    }
+  };
+
+  const handleEnable = async () => {
+    setEnableRunning(true);
+    const res = await fetch(`/api/w/${owner.sId}/providers/textsynth`, {
+      headers: {
+        "Content-Type": "application/json",
+      },
+      method: "POST",
+      body: JSON.stringify({
+        config: JSON.stringify({
+          api_key: apiKey,
+        }),
+      }),
+    });
+    await res.json();
+    setEnableRunning(false);
+    setOpen(false);
+    await mutate(`/api/w/${owner.sId}/providers`);
+  };
+
+  const handleDisable = async () => {
+    const res = await fetch(`/api/w/${owner.sId}/providers/textsynth`, {
+      method: "DELETE",
+    });
+    await res.json();
+    setOpen(false);
+    await mutate(`/api/w/${owner.sId}/providers`);
+  };
+
+  return (
+    <Transition.Root show={open} as={Fragment}>
+      <Dialog as="div" className="relative z-10" onClose={() => setOpen(false)}>
+        <Transition.Child
+          as={Fragment}
+          enter="ease-out duration-300"
+          enterFrom="opacity-0"
+          enterTo="opacity-100"
+          leave="ease-in duration-200"
+          leaveFrom="opacity-100"
+          leaveTo="opacity-0"
+        >
+          <div className="fixed inset-0 bg-gray-500 bg-opacity-75 transition-opacity" />
+        </Transition.Child>
+
+        <div className="fixed inset-0 z-10 overflow-y-auto">
+          <div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
+            <Transition.Child
+              as={Fragment}
+              enter="ease-out duration-300"
+              enterFrom="opacity-0 translate-y-4 sm:translate-y-0 sm:scale-95"
+              enterTo="opacity-100 translate-y-0 sm:scale-100"
+              leave="ease-in duration-200"
+              leaveFrom="opacity-100 translate-y-0 sm:scale-100"
+              leaveTo="opacity-0 translate-y-4 sm:translate-y-0 sm:scale-95"
+            >
+              <Dialog.Panel className="relative transform overflow-hidden rounded-lg bg-white px-4 pb-4 pt-5 text-left shadow-xl transition-all sm:my-8 sm:w-full sm:max-w-sm sm:p-6 lg:max-w-lg">
+                <div>
+                  <div className="mt-3">
+                    <Dialog.Title
+                      as="h3"
+                      className="text-base font-semibold leading-6 text-gray-900"
+                    >
+                      Setup TextSynth
+                    </Dialog.Title>
+                    <div className="mt-2">
+                      <p className="text-sm text-gray-500">
+                        To use TextSynth models you must provide your API key.
+                        It can be found{" "}
+                        <a
+                          className="font-bold text-action-600 hover:text-action-500"
+                          href="https://textsynth.com/settings.html"
+                          target="_blank"
+                        >
+                          here
+                        </a>
+                        .
+                      </p>
+                      <p className="mt-2 text-sm text-gray-500">
+                        We'll never use your API key for anything other than to
+                        run your apps.
+                      </p>
+                    </div>
+                    <div className="mt-6">
+                      <input
+                        type="text"
+                        name="name"
+                        id="name"
+                        className="block w-full rounded-md border-gray-300 shadow-sm focus:border-violet-500 focus:ring-violet-500 sm:text-sm"
+                        placeholder="TextSynth API Key"
+                        value={apiKey}
+                        onChange={(e) => {
+                          setApiKey(e.target.value);
+                          setTestSuccessful(false);
+                        }}
+                      />
+                    </div>
+                  </div>
+                </div>
+                <div className="mt-1 px-2 text-sm">
+                  {testError.length > 0 ? (
+                    <span className="text-red-500">Error: {testError}</span>
+                  ) : testSuccessful ? (
+                    <span className="text-green-600">
+                      Test succeeded! You can enable TextSynth.
+                    </span>
+                  ) : (
+                    <span>&nbsp;</span>
+                  )}
+                </div>
+                <div className="mt-5 flex flex-row items-center space-x-2 sm:mt-6">
+                  {enabled ? (
+                    <div className="flex-initial">
+                      <Button onClick={() => handleDisable()}>
+                        Disable
+                      </Button>
+                    </div>
+                  ) : (
+                    <></>
+                  )}
+                  <div className="flex-1"></div>
+                  <div className="flex-initial">
+                    {/* TODO: typescript */}
+                    {testSuccessful ? (
+                      <Button
+                        onClick={handleEnable}
+                        disabled={enableRunning}
+                        label={enabled ? "Update" : "Enable"}
+                      />
+                    ) : (
+                      <Button
+                        onClick={runTest}
+                        disabled={apiKey.length == 0 || testRunning}
+                        label={testRunning ? "Testing..." : "Test"}
+                      />
+                    )}
+                  </div>
+                </div>
+              </Dialog.Panel>
+            </Transition.Child>
+          </div>
+        </div>
+      </Dialog>
+    </Transition.Root>
+  );
+}
diff --git a/front/documents_post_process_hooks/hooks/document_tracker/suggest_changes/lib.ts b/front/documents_post_process_hooks/hooks/document_tracker/suggest_changes/lib.ts
index 7efbcb16f7f8..9c4b70158f3f 100644
--- a/front/documents_post_process_hooks/hooks/document_tracker/suggest_changes/lib.ts
+++ b/front/documents_post_process_hooks/hooks/document_tracker/suggest_changes/lib.ts
@@ -226,8 +226,8 @@ export async function documentTrackerSuggestChangesOnUpsert({

   const tokensInDiff = await CoreAPI.tokenize({
     text: diffText,
-    modelId: "text-embedding-ada-002",
     providerId: "openai",
+    modelId: "gpt-3.5-turbo",
   });
   if (tokensInDiff.isErr()) {
     throw tokensInDiff.error;
diff --git a/front/lib/actions/registry.ts b/front/lib/actions/registry.ts
index c9fcf5992bf3..6e124d06b7b4 100644
--- a/front/lib/actions/registry.ts
+++ b/front/lib/actions/registry.ts
@@ -78,87 +78,6 @@ export const DustProdActionRegistry = createActionRegistry({
     },
   },

-  "chat-retrieval": {
-    app: {
-      workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID,
-      appId: "0d7ab66fd2",
-      appHash:
-        "63d4bea647370f23fa396dc59347cfbd92354bced26783c9a99812a8b1b14371",
-    },
-    config: {
-      DATASOURCE: {
-        data_sources: [],
-        top_k: 16,
-        filter: { tags: null, timestamp: null },
-        use_cache: false,
-      },
-    },
-  },
-  "chat-assistant": {
-    app: {
-      workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID,
-      appId: "ab43ff2450",
-      appHash:
-        "5ba93a4b1750336ff15614b16d2f735de444e63ff22ec03f8c6e8b48392e0ea5",
-    },
-    config: {
-      MODEL: {
-        provider_id: "openai",
-        model_id: "gpt-3.5-turbo",
-        use_cache: true,
-        use_stream: true,
-      },
-    },
-  },
-  "chat-assistant-wfn": {
-    app: {
-      workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID,
-      appId: "0052be4be7",
-      appHash:
-        "e2291ceaa774bf3a05d1c3a20db6b6de613070ae683704d6e3076d2711755d81",
-    },
-    config: {
-      MODEL: {
-        provider_id: "openai",
-        model_id: "gpt-4-0613",
-        function_call: "auto",
-        use_cache: true,
-        use_stream: true,
-      },
-    },
-  },
-  "chat-message-e2e-eval": {
-    app: {
-      workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID,
-      appId: "201e6de608",
-      appHash:
-        "f676a2de5b83bdbd02b55beaefddafb0bb16f53715d9d7ba3c219ec09b6b0588",
-    },
-    config: {
-      RULE_VALIDITY: {
-        provider_id: "openai",
-        model_id: "gpt-3.5-turbo-0613",
-        function_call: "send_rule_validity",
-        use_cache: true,
-      },
-    },
-  },
-  "chat-title": {
-    app: {
-      workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID,
-      appId: "8fe968eef3",
-      appHash:
-        "528251ebc7b5bc59027877842ca6ff05c64c08765c3ab1f5e1b799395cdb3b57",
-    },
-    config: {
-      TITLE_CHAT: {
-        provider_id: "openai",
-        model_id: "gpt-3.5-turbo-0613",
-        function_call: "post_title",
-        use_cache: true,
-      },
-    },
-  },
   "doc-tracker-retrieval": {
     app: {
       workspaceId: PRODUCTION_DUST_APPS_WORKSPACE_ID,
@@ -192,7 +111,7 @@
     config: {
       SUGGEST_CHANGES: {
         provider_id: "openai",
-        model_id: "gpt-4-0613",
+        model_id: "gpt-4",
         use_cache: false,
         function_call: "suggest_changes",
       },
@@ -208,7 +127,7 @@
     config: {
       MODEL: {
         provider_id: "openai",
-        model_id: "gpt-4-0613",
+        model_id: "gpt-4",
         use_cache: false,
         function_call: "extract_events",
       },
diff --git a/front/lib/api/credentials.ts b/front/lib/api/credentials.ts
index 43cef35f2550..f50f3cdc2632 100644
--- a/front/lib/api/credentials.ts
+++ b/front/lib/api/credentials.ts
@@ -1,7 +1,10 @@
 import { CredentialsType, ProviderType } from "@app/types/provider";

-const { DUST_MANAGED_OPENAI_API_KEY = "", DUST_MANAGED_ANTHROPIC_API_KEY } =
-  process.env;
+const {
+  DUST_MANAGED_OPENAI_API_KEY = "",
+  DUST_MANAGED_ANTHROPIC_API_KEY = "",
+  DUST_MANAGED_TEXTSYNTH_API_KEY = "",
+} = process.env;

 export const credentialsFromProviders = (
   providers: ProviderType[]
@@ -31,6 +34,9 @@
       case "anthropic":
         credentials["ANTHROPIC_API_KEY"] = config.api_key;
         break;
+      case "textsynth":
+        credentials["TEXTSYNTH_API_KEY"] = config.api_key;
+        break;
       case "serpapi":
         credentials["SERP_API_KEY"] = config.api_key;
         break;
@@ -49,5 +55,6 @@ export const dustManagedCredentials = (): CredentialsType => {
   return {
     OPENAI_API_KEY: DUST_MANAGED_OPENAI_API_KEY,
     ANTHROPIC_API_KEY: DUST_MANAGED_ANTHROPIC_API_KEY,
+    TEXTSYNTH_API_KEY: DUST_MANAGED_TEXTSYNTH_API_KEY,
   };
 };
diff --git a/front/lib/extract_event_app.ts b/front/lib/extract_event_app.ts
index 2eaa9a79efcf..52b70473050a 100644
--- a/front/lib/extract_event_app.ts
+++ b/front/lib/extract_event_app.ts
@@ -167,11 +167,10 @@ function extractMaxTokens({
 export async function getTokenizedText(
   text: string
 ): Promise<{ tokens: CoreAPITokenType[] }> {
-  console.log("computeNbTokens4");
   const tokenizeResponse = await CoreAPI.tokenize({
     text: text,
-    modelId: "text-embedding-ada-002",
     providerId: "openai",
+    modelId: "gpt-4",
   });
   if (tokenizeResponse.isErr()) {
diff --git a/front/lib/providers.ts b/front/lib/providers.ts
index 7fa354d6f15d..baf0dd974d4d 100644
--- a/front/lib/providers.ts
+++ b/front/lib/providers.ts
@@ -53,6 +53,14 @@ export const modelProviders: ModelProvider[] = [
     chat: true,
     embed: false,
   },
+  {
+    providerId: "textsynth",
+    name: "TextSynth",
+    built: true,
+    enabled: false,
+    chat: false,
+    embed: false,
+  },
   {
     providerId: "hugging_face",
     name: "Hugging Face",
diff --git a/front/package-lock.json b/front/package-lock.json
index d9f00b2e9e93..042ed1ebd24a 100644
--- a/front/package-lock.json
+++ b/front/package-lock.json
@@ -5,7 +5,7 @@
   "packages": {
     "": {
       "dependencies": {
-        "@dust-tt/sparkle": "0.1.94",
+        "@dust-tt/sparkle": "0.1.95",
         "@headlessui/react": "^1.7.7",
         "@heroicons/react": "^2.0.11",
         "@nangohq/frontend": "^0.16.1",
@@ -722,9 +722,9 @@
       "license": "Apache-2.0"
     },
     "node_modules/@dust-tt/sparkle": {
-      "version": "0.1.94",
-      "resolved": "https://registry.npmjs.org/@dust-tt/sparkle/-/sparkle-0.1.94.tgz",
-      "integrity": "sha512-sT3ygsaUKVm4WniYw6x5VUsZ7Dr5K37FvU6jwWWaBTEzCkROCVHwOSAfbhsx8zkvQp+orBREZ+yZmhGYWvvjlQ==",
+      "version": "0.1.95",
+      "resolved": "https://registry.npmjs.org/@dust-tt/sparkle/-/sparkle-0.1.95.tgz",
+      "integrity": "sha512-id1/gka5Z7Ql4A7tCwGzrHfgoigyT7gTU4QZ6omwlriWE3giD3QRon54sSAF6aMWhYa9h/XS0ctbWM7vkNSnaQ==",
       "dependencies": {
         "@headlessui/react": "^1.7.17"
       },
diff --git a/front/package.json b/front/package.json
index ecec6ebd676e..09ec2d7a70bb 100644
--- a/front/package.json
+++ b/front/package.json
@@ -13,7 +13,7 @@
     "initdb": "env $(cat .env.local) npx tsx admin/db.ts"
   },
   "dependencies": {
-    "@dust-tt/sparkle": "0.1.94",
+    "@dust-tt/sparkle": "0.1.95",
     "@headlessui/react": "^1.7.7",
     "@heroicons/react": "^2.0.11",
     "@nangohq/frontend": "^0.16.1",
diff --git a/front/pages/api/w/[wId]/providers/[pId]/check.ts b/front/pages/api/w/[wId]/providers/[pId]/check.ts
index fe2f01cbdc30..c0e20658f378 100644
--- a/front/pages/api/w/[wId]/providers/[pId]/check.ts
+++ b/front/pages/api/w/[wId]/providers/[pId]/check.ts
@@ -144,6 +144,31 @@ async function handler(
       }
       return;

+    case "textsynth":
+      const testCompletion = await fetch(
+        "https://api.textsynth.com/v1/engines/mistral_7B/completions",
+        {
+          method: "POST",
+          headers: {
+            Authorization: `Bearer ${config.api_key}`,
+            "Content-Type": "application/json",
+          },
+          body: JSON.stringify({
+            prompt: "",
+            max_tokens: 1,
+          }),
+        }
+      );
+
+      if (!testCompletion.ok) {
+        const err = await testCompletion.json();
+        res.status(400).json({ ok: false, error: err.error });
+      } else {
+        await testCompletion.json();
+        res.status(200).json({ ok: true });
+      }
+      return;
+
     case "serpapi":
       const testSearch = await fetch(
         `https://serpapi.com/search?engine=google&q=Coffee&api_key=${config.api_key}`,
diff --git a/front/pages/api/w/[wId]/providers/[pId]/models.ts b/front/pages/api/w/[wId]/providers/[pId]/models.ts
index bacaaf6f25ba..45a664dac994 100644
--- a/front/pages/api/w/[wId]/providers/[pId]/models.ts
+++ b/front/pages/api/w/[wId]/providers/[pId]/models.ts
@@ -200,6 +200,27 @@ async function handler(
       res.status(200).json({ models: anthropic_models });
       return;

+    case "textsynth":
+      if (chat) {
+        res.status(200).json({
+          models: [
+            // { id: "mistral_7B_instruct" },
+            // { id: "falcon_40B-chat" },
+          ],
+        });
+        return;
+      }
+      res.status(200).json({
+        models: [
+          { id: "mistral_7B" },
+          { id: "mistral_7B_instruct" },
+          { id: "falcon_7B" },
+          { id: "falcon_40B" },
+          { id: "llama2_7B" },
+        ],
+      });
+      return;
+
     default:
       res.status(404).json({ error: "Provider not found" });
       return;
diff --git a/front/pages/w/[wId]/a/[aId]/settings.tsx b/front/pages/w/[wId]/a/[aId]/settings.tsx
index 8f872e29df01..fa0cdc8d258f 100644
--- a/front/pages/w/[wId]/a/[aId]/settings.tsx
+++ b/front/pages/w/[wId]/a/[aId]/settings.tsx
@@ -109,7 +109,7 @@ export default function SettingsView({
       method: "DELETE",
     });
     if (res.ok) {
-      await router.push(`/w/${owner.sId}/`);
+      await router.push(`/w/${owner.sId}/a`);
     } else {
       setIsDeleting(false);
       const err = (await res.json()) as { error: APIError };
diff --git a/front/pages/w/[wId]/a/index.tsx b/front/pages/w/[wId]/a/index.tsx
index 202ab3dc18b6..4f61da4f4ce0 100644
--- a/front/pages/w/[wId]/a/index.tsx
+++ b/front/pages/w/[wId]/a/index.tsx
@@ -20,6 +20,7 @@ import CohereSetup from "@app/components/providers/CohereSetup";
 import OpenAISetup from "@app/components/providers/OpenAISetup";
 import SerpAPISetup from "@app/components/providers/SerpAPISetup";
 import SerperSetup from "@app/components/providers/SerperSetup";
+import TextSynthSetup from "@app/components/providers/TextSynthSetup";
 import AppLayout from "@app/components/sparkle/AppLayout";
 import { subNavigationAdmin } from "@app/components/sparkle/navigation";
 import { getApps } from "@app/lib/api/app";
@@ -203,6 +204,7 @@ export function Providers({ owner }: { owner: WorkspaceType }) {
   const [ai21Open, setAI21Open] = useState(false);
   const [azureOpenAIOpen, setAzureOpenAIOpen] = useState(false);
   const [anthropicOpen, setAnthropicOpen] = useState(false);
+  const [textSynthOpen, setTextSynthOpen] = useState(false);
   const [serpapiOpen, setSerpapiOpen] = useState(false);
   const [serperOpen, setSerperOpen] = useState(false);
   const [browserlessapiOpen, setBrowserlessapiOpen] = useState(false);
@@ -255,6 +257,13 @@
         enabled={configs["anthropic"] ? true : false}
         config={configs["anthropic"] ? configs["anthropic"] : null}
       />
+      <TextSynthSetup
+        owner={owner}
+        open={textSynthOpen}
+        setOpen={setTextSynthOpen}
+        enabled={configs["textsynth"] ? true : false}
+        config={configs["textsynth"] ? configs["textsynth"] : null}
+      />
diff --git a/sparkle/src/components/Tooltip.tsx b/sparkle/src/components/Tooltip.tsx
--- a/sparkle/src/components/Tooltip.tsx
+++ b/sparkle/src/components/Tooltip.tsx
@@ ... @@
-  const labelClasses = labelLength > 80 ? "w-[38em]" : "s-whitespace-nowrap";
+  const labelClasses = labelLength > 80 ? "s-w-[38em]" : "s-whitespace-nowrap";
   return (
diff --git a/sparkle/src/stories/Tooltip.stories.tsx b/sparkle/src/stories/Tooltip.stories.tsx
--- a/sparkle/src/stories/Tooltip.stories.tsx
+++ b/sparkle/src/stories/Tooltip.stories.tsx
@@ ... @@
 );
+
+export const TooltipLongLabel = () => (
+  <Tooltip ...>
+    Hover me
+  </Tooltip>
+);
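
For reference, the new `tokenize` trait method introduced by this diff can be exercised end to end through the provider registry. The following is a minimal sketch, not part of the diff; it assumes the core crate is importable as `dust` and that `TEXTSYNTH_API_KEY` is set in the environment:

```rust
use dust::providers::provider::{provider, ProviderID};
use dust::run::Credentials;

// Resolve the TextSynth provider, initialize it from the environment, and
// tokenize a prompt. The TextSynth tokenize endpoint only returns token ids,
// so the string half of each (id, string) pair is empty (see
// `TextSynthLLM::tokenize` above).
async fn tokenize_demo() -> anyhow::Result<()> {
    let mut llm = provider(ProviderID::TextSynth).llm(String::from("mistral_7B"));
    llm.initialize(Credentials::new()).await?;
    let tokens = llm.tokenize("Soupinou").await?;
    assert!(tokens.iter().all(|(_, s)| s.is_empty()));
    Ok(())
}
```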