diff --git a/core/src/lib.rs b/core/src/lib.rs index b91e0705dad9..0a49058d4f72 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -30,6 +30,7 @@ pub mod providers { pub mod tiktoken; } pub mod anthropic; + pub mod textsynth; } pub mod http { pub mod request; diff --git a/core/src/providers/provider.rs b/core/src/providers/provider.rs index a6f7ca78a321..0ad36b27b51c 100644 --- a/core/src/providers/provider.rs +++ b/core/src/providers/provider.rs @@ -13,6 +13,8 @@ use serde::{Deserialize, Serialize}; use std::str::FromStr; use std::time::Duration; +use super::textsynth::TextSynthProvider; + #[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ProviderID { @@ -22,6 +24,7 @@ pub enum ProviderID { #[serde(rename = "azure_openai")] AzureOpenAI, Anthropic, + TextSynth, } impl ToString for ProviderID { @@ -32,6 +35,7 @@ impl ToString for ProviderID { ProviderID::AI21 => String::from("ai21"), ProviderID::AzureOpenAI => String::from("azure_openai"), ProviderID::Anthropic => String::from("anthropic"), + ProviderID::TextSynth => String::from("textsynth"), } } } @@ -45,6 +49,7 @@ impl FromStr for ProviderID { "ai21" => Ok(ProviderID::AI21), "azure_openai" => Ok(ProviderID::AzureOpenAI), "anthropic" => Ok(ProviderID::Anthropic), + "textsynth" => Ok(ProviderID::TextSynth), _ => Err(ParseError::with_message( "Unknown provider ID (possible values: openai, cohere, ai21, azure_openai)", ))?, @@ -139,5 +144,6 @@ pub fn provider(t: ProviderID) -> Box { ProviderID::AI21 => Box::new(AI21Provider::new()), ProviderID::AzureOpenAI => Box::new(AzureOpenAIProvider::new()), ProviderID::Anthropic => Box::new(AnthropicProvider::new()), + ProviderID::TextSynth => Box::new(TextSynthProvider::new()), } } diff --git a/core/src/providers/textsynth.rs b/core/src/providers/textsynth.rs new file mode 100644 index 000000000000..5c316f8d1660 --- /dev/null +++ b/core/src/providers/textsynth.rs @@ -0,0 +1,482 @@ +use 
crate::providers::embedder::Embedder; +use crate::providers::llm::Tokens; +use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLM}; +use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID}; +use crate::run::Credentials; +use crate::utils; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use hyper::{body::Buf, Body, Client, Method, Request, Uri}; +use hyper_tls::HttpsConnector; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::io::prelude::*; +use std::time::Duration; +use tokio::sync::mpsc::UnboundedSender; + +use super::embedder::EmbedderVector; +use super::llm::ChatFunction; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Error { + pub error: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct TokenizeResponse { + pub tokens: Vec, +} + +async fn api_tokenize(api_key: &str, engine: &str, text: &str) -> Result> { + let https = HttpsConnector::new(); + let cli = Client::builder().build::<_, hyper::Body>(https); + + let req = Request::builder() + .method(Method::POST) + .uri(format!("https://api.textsynth.com/v1/engines/{}/tokenize", engine).parse::()?) 
+ .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(Body::from( + json!({ + "text": text, + }) + .to_string(), + ))?; + + let res = cli.request(req).await?; + let status = res.status(); + let body = hyper::body::aggregate(res).await?; + let mut b: Vec = vec![]; + body.reader().read_to_end(&mut b)?; + let c: &[u8] = &b; + + let r = match status { + hyper::StatusCode::OK => { + let r: TokenizeResponse = serde_json::from_slice(c)?; + Ok(r) + } + hyper::StatusCode::TOO_MANY_REQUESTS => { + let error: Error = serde_json::from_slice(c).unwrap_or(Error { + error: "Too many requests".to_string(), + }); + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(2000), + factor: 2, + retries: 8, + }), + }) + } + hyper::StatusCode::BAD_REQUEST => { + let error: Error = serde_json::from_slice(c).unwrap_or(Error { + error: "Unknown error".to_string(), + }); + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: None, + }) + } + _ => { + let error: Error = serde_json::from_slice(c)?; + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: None, + }) + } + }?; + Ok(r.tokens) +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Completion { + pub text: String, + pub reached_end: bool, + pub input_tokens: usize, + pub output_tokens: usize, +} + +pub struct TextSynthLLM { + id: String, + api_key: Option, +} + +impl TextSynthLLM { + pub fn new(id: String) -> Self { + TextSynthLLM { id, api_key: None } + } + + fn uri(&self) -> Result { + Ok(format!( + "https://api.textsynth.com/v1/engines/{}/completions", + self.id + ) + .parse::()?) + } + + // fn chat_uri(&self) -> Result { + // Ok(format!("https://api.textsynth.com/v1/engines/{}/chat", self.id).parse::()?) 
+ // } + + async fn completion( + &self, + prompt: &str, + max_tokens: Option, + temperature: f32, + stop: &Vec, + top_k: Option, + top_p: Option, + frequency_penalty: Option, + presence_penalty: Option, + repetition_penalty: Option, + typical_p: Option, + ) -> Result { + assert!(self.api_key.is_some()); + + let https = HttpsConnector::new(); + let cli = Client::builder().build::<_, hyper::Body>(https); + + let mut body = json!({ + "prompt": prompt, + "temperature": temperature, + }); + if let Some(max_tokens) = max_tokens { + body["max_tokens"] = json!(max_tokens); + } + if !stop.is_empty() { + body["stop"] = json!(stop); + } + if let Some(top_k) = top_k { + body["top_k"] = json!(top_k); + } + if let Some(top_p) = top_p { + body["top_p"] = json!(top_p); + } + if let Some(frequency_penalty) = frequency_penalty { + body["frequency_penalty"] = json!(frequency_penalty); + } + if let Some(presence_penalty) = presence_penalty { + body["presence_penalty"] = json!(presence_penalty); + } + if let Some(repetition_penalty) = repetition_penalty { + body["repetition_penalty"] = json!(repetition_penalty); + } + if let Some(typical_p) = typical_p { + body["typical_p"] = json!(typical_p); + } + + let req = Request::builder() + .method(Method::POST) + .uri(self.uri()?) 
+ .header("Content-Type", "application/json") + .header( + "Authorization", + format!("Bearer {}", self.api_key.clone().unwrap()), + ) + .body(Body::from(body.to_string()))?; + + let res = cli.request(req).await?; + let status = res.status(); + let body = hyper::body::aggregate(res).await?; + let mut b: Vec = vec![]; + body.reader().read_to_end(&mut b)?; + let c: &[u8] = &b; + + let response = match status { + hyper::StatusCode::OK => { + let completion: Completion = serde_json::from_slice(c)?; + Ok(completion) + } + hyper::StatusCode::TOO_MANY_REQUESTS => { + let error: Error = serde_json::from_slice(c).unwrap_or(Error { + error: "Too many requests".to_string(), + }); + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: Some(ModelErrorRetryOptions { + sleep: Duration::from_millis(2000), + factor: 2, + retries: 8, + }), + }) + } + hyper::StatusCode::BAD_REQUEST => { + let error: Error = serde_json::from_slice(c).unwrap_or(Error { + error: "Unknown error".to_string(), + }); + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: None, + }) + } + _ => { + let error: Error = serde_json::from_slice(c)?; + Err(ModelError { + message: format!("TextSynthAPIError: {}", error.error), + retryable: None, + }) + } + }?; + Ok(response) + } +} + +#[async_trait] +impl LLM for TextSynthLLM { + fn id(&self) -> String { + self.id.clone() + } + + async fn initialize(&mut self, credentials: Credentials) -> Result<()> { + match credentials.get("TEXTSYNTH_API_KEY") { + Some(api_key) => { + self.api_key = Some(api_key.clone()); + } + None => match tokio::task::spawn_blocking(|| std::env::var("TEXTSYNTH_API_KEY")).await? + { + Ok(key) => { + self.api_key = Some(key); + } + Err(_) => Err(anyhow!( + "Credentials or environment variable `TEXTSYNTH_API_KEY` is not set." 
+ ))?, + }, + } + Ok(()) + } + + fn context_size(&self) -> usize { + match self.id.as_str() { + "mistral_7B" => 4096, + "mistral_7B_instruct" => 4096, + "falcon_7B" => 2048, + "falcon_40B" => 2048, + "falcon_40B-chat" => 2048, + "llama2_7B" => 4096, + _ => 2048, + } + } + + async fn encode(&self, text: &str) -> Result> { + assert!(self.api_key.is_some()); + + api_tokenize( + self.api_key.clone().unwrap().as_str(), + self.id.as_str(), + text, + ) + .await + } + + async fn decode(&self, _tokens: Vec) -> Result { + Err(anyhow!( + "Encode/Decode not implemented for provider `textsynth`" + )) + } + + async fn generate( + &self, + prompt: &str, + mut max_tokens: Option, + temperature: f32, + _n: usize, + stop: &Vec, + frequency_penalty: Option, + presence_penalty: Option, + top_p: Option, + _top_logprobs: Option, + _extras: Option, + _event_sender: Option>, + ) -> Result { + assert!(self.api_key.is_some()); + + if let Some(m) = max_tokens { + if m == -1 { + let tokens = self.encode(prompt).await?; + max_tokens = Some(self.context_size().saturating_sub(tokens.len()) as i32); + } + } + + // println!("STOP: {:?}", stop); + + let c = self + .completion( + prompt, + max_tokens, + temperature, + stop, + None, + top_p, + frequency_penalty, + presence_penalty, + None, + None, + ) + .await?; + + // println!("COMPLETION: {:?}", c); + + Ok(LLMGeneration { + created: utils::now(), + provider: ProviderID::TextSynth.to_string(), + model: self.id.clone(), + completions: vec![Tokens { + text: c.text.clone(), + tokens: Some(vec![]), + logprobs: Some(vec![]), + top_logprobs: Some(vec![]), + }], + prompt: Tokens { + text: prompt.to_string(), + tokens: None, + logprobs: None, + top_logprobs: None, + }, + }) + } + + async fn chat( + &self, + _messages: &Vec, + _functions: &Vec, + _function_call: Option, + _temperature: f32, + _top_p: Option, + _n: usize, + _stop: &Vec, + _max_tokens: Option, + _presence_penalty: Option, + _frequency_penalty: Option, + _extras: Option, + _event_sender: Option>, + ) -> 
Result { + Err(anyhow!( + "Chat capabilties are not implemented for provider `textsynth`" + )) + } +} + +pub struct TextSynthEmbedder { + id: String, +} + +impl TextSynthEmbedder { + pub fn new(id: String) -> Self { + TextSynthEmbedder { id } + } +} + +#[async_trait] +impl Embedder for TextSynthEmbedder { + fn id(&self) -> String { + self.id.clone() + } + + async fn initialize(&mut self, _credentials: Credentials) -> Result<()> { + Err(anyhow!("Embedders not available for provider `textsynth`")) + } + + fn context_size(&self) -> usize { + 2048 + } + fn embedding_size(&self) -> usize { + 2048 + } + + async fn encode(&self, _text: &str) -> Result> { + Err(anyhow!( + "Encode/Decode not implemented for provider `textsynth`" + )) + } + + async fn decode(&self, _tokens: Vec) -> Result { + Err(anyhow!( + "Encode/Decode not implemented for provider `textsynth`" + )) + } + + async fn tokenize(&self, _text: String) -> Result> { + Err(anyhow!("Tokenize not implemented for provider `textsynth`")) + } + + async fn embed(&self, _text: Vec<&str>, _extras: Option) -> Result> { + Err(anyhow!("Embeddings not available for provider `textsynth`")) + } +} + +pub struct TextSynthProvider {} + +impl TextSynthProvider { + pub fn new() -> Self { + TextSynthProvider {} + } +} + +#[async_trait] +impl Provider for TextSynthProvider { + fn id(&self) -> ProviderID { + ProviderID::TextSynth + } + + fn setup(&self) -> Result<()> { + utils::info("Setting up TextSynth:"); + utils::info(""); + utils::info( + "To use TextSynth's models, you must set the environment variable `TEXTSYNTH_API_KEY`.", + ); + utils::info("Your API key can be found at `https://textsynth.com/settings.html`."); + utils::info(""); + utils::info("Once ready you can check your setup with `dust provider test textsynth`"); + + Ok(()) + } + + async fn test(&self) -> Result<()> { + if !utils::confirm( + "You are about to make a request for 1 token to `mistral_7B` on the TextSynth API.", + )? 
{ + Err(anyhow!("User aborted TextSynth test."))?; + } + + let mut llm = self.llm(String::from("mistral_7B")); + llm.initialize(Credentials::new()).await?; + + let _ = llm + .generate( + "Hello 😊", + Some(1), + 0.7, + 1, + &vec![], + None, + None, + None, + None, + None, + None, + ) + .await?; + + // let t = llm.encode("Hello 😊").await?; + // let d = llm.decode(t).await?; + // assert!(d == "Hello 😊"); + + // let mut embedder = self.embedder(String::from("large")); + // embedder.initialize(Credentials::new()).await?; + + // let _v = embedder.embed("Hello 😊", None).await?; + // println!("EMBEDDING SIZE: {}", v.vector.len()); + + utils::done("Test successfully completed! TextSynth is ready to use."); + + Ok(()) + } + + fn llm(&self, id: String) -> Box { + Box::new(TextSynthLLM::new(id)) + } + + fn embedder(&self, id: String) -> Box { + Box::new(TextSynthEmbedder::new(id)) + } +} diff --git a/front/components/providers/TextSynthSetup.tsx b/front/components/providers/TextSynthSetup.tsx new file mode 100644 index 000000000000..09ad937308b0 --- /dev/null +++ b/front/components/providers/TextSynthSetup.tsx @@ -0,0 +1,208 @@ +import { Button } from "@dust-tt/sparkle"; +import { Dialog, Transition } from "@headlessui/react"; +import { Fragment, useEffect, useState } from "react"; +import { useSWRConfig } from "swr"; + +import { checkProvider } from "@app/lib/providers"; +import { WorkspaceType } from "@app/types/user"; + +export default function TextSynthSetup({ + owner, + open, + setOpen, + config, + enabled, +}: { + owner: WorkspaceType; + open: boolean; + setOpen: (open: boolean) => void; + config: { [key: string]: string }; + enabled: boolean; +}) { + const { mutate } = useSWRConfig(); + + const [apiKey, setApiKey] = useState(config ? 
config.api_key : ""); + const [testSuccessful, setTestSuccessful] = useState(false); + const [testRunning, setTestRunning] = useState(false); + const [testError, setTestError] = useState(""); + const [enableRunning, setEnableRunning] = useState(false); + + useEffect(() => { + if (config && config.api_key.length > 0 && apiKey.length == 0) { + setApiKey(config.api_key); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [config]); + + const runTest = async () => { + setTestRunning(true); + setTestError(""); + const check = await checkProvider(owner, "textsynth", { + api_key: apiKey, + }); + + if (!check.ok) { + setTestError(check.error || "Unknown error"); + setTestSuccessful(false); + setTestRunning(false); + } else { + setTestError(""); + setTestSuccessful(true); + setTestRunning(false); + } + }; + + const handleEnable = async () => { + setEnableRunning(true); + const res = await fetch(`/api/w/${owner.sId}/providers/textsynth`, { + headers: { + "Content-Type": "application/json", + }, + method: "POST", + body: JSON.stringify({ + config: JSON.stringify({ + api_key: apiKey, + }), + }), + }); + await res.json(); + setEnableRunning(false); + setOpen(false); + await mutate(`/api/w/${owner.sId}/providers`); + }; + + const handleDisable = async () => { + const res = await fetch(`/api/w/${owner.sId}/providers/textsynth`, { + method: "DELETE", + }); + await res.json(); + setOpen(false); + await mutate(`/api/w/${owner.sId}/providers`); + }; + + return ( + + setOpen(false)}> + +
+ + +
+
+ + +
+
+ + Setup TextSynth + +
+

+ To use TextSynth models you must provide your API key. + It can be found{" "} + + here + +

+

+ We'll never use your API key for anything other than to + run your apps. +

+
+
+ { + setApiKey(e.target.value); + setTestSuccessful(false); + }} + /> +
+
+
+
+ {testError.length > 0 ? ( + Error: {testError} + ) : testSuccessful ? ( + + Test succeeded! You can enable TextSynth. + + ) : ( +   + )} +
+
+ {enabled ? ( +
handleDisable()} + > + Disable +
+ ) : ( + <> + )} +
+
+ {/* TODO: typescript */} +
+
+ {testSuccessful ? ( +
+
+
+
+
+
+
+
+ ); +} diff --git a/front/lib/api/credentials.ts b/front/lib/api/credentials.ts index 43cef35f2550..f50f3cdc2632 100644 --- a/front/lib/api/credentials.ts +++ b/front/lib/api/credentials.ts @@ -1,7 +1,10 @@ import { CredentialsType, ProviderType } from "@app/types/provider"; -const { DUST_MANAGED_OPENAI_API_KEY = "", DUST_MANAGED_ANTHROPIC_API_KEY } = - process.env; +const { + DUST_MANAGED_OPENAI_API_KEY = "", + DUST_MANAGED_ANTHROPIC_API_KEY = "", + DUST_MANAGED_TEXTSYNTH_API_KEY = "", +} = process.env; export const credentialsFromProviders = ( providers: ProviderType[] @@ -31,6 +34,9 @@ export const credentialsFromProviders = ( case "anthropic": credentials["ANTHROPIC_API_KEY"] = config.api_key; break; + case "textsynth": + credentials["TEXTSYNTH_API_KEY"] = config.api_key; + break; case "serpapi": credentials["SERP_API_KEY"] = config.api_key; break; @@ -49,5 +55,6 @@ export const dustManagedCredentials = (): CredentialsType => { return { OPENAI_API_KEY: DUST_MANAGED_OPENAI_API_KEY, ANTHROPIC_API_KEY: DUST_MANAGED_ANTHROPIC_API_KEY, + TEXTSYNTH_API_KEY: DUST_MANAGED_TEXTSYNTH_API_KEY, }; }; diff --git a/front/lib/providers.ts b/front/lib/providers.ts index 7fa354d6f15d..baf0dd974d4d 100644 --- a/front/lib/providers.ts +++ b/front/lib/providers.ts @@ -53,6 +53,14 @@ export const modelProviders: ModelProvider[] = [ chat: true, embed: false, }, + { + providerId: "textsynth", + name: "TextSynth", + built: true, + enabled: false, + chat: false, + embed: false, + }, { providerId: "hugging_face", name: "Hugging Face", diff --git a/front/pages/api/w/[wId]/providers/[pId]/check.ts b/front/pages/api/w/[wId]/providers/[pId]/check.ts index fe2f01cbdc30..c0e20658f378 100644 --- a/front/pages/api/w/[wId]/providers/[pId]/check.ts +++ b/front/pages/api/w/[wId]/providers/[pId]/check.ts @@ -144,6 +144,31 @@ async function handler( } return; + case "textsynth": + const testCompletion = await fetch( + "https://api.textsynth.com/v1/engines/mistral_7B/completions", + { + method: 
"POST", + headers: { + Authorization: `Bearer ${config.api_key}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + prompt: "", + max_tokens: 1, + }), + } + ); + + if (!testCompletion.ok) { + const err = await testCompletion.json(); + res.status(400).json({ ok: false, error: err.error }); + } else { + await testCompletion.json(); + res.status(200).json({ ok: true }); + } + return; + case "serpapi": const testSearch = await fetch( `https://serpapi.com/search?engine=google&q=Coffee&api_key=${config.api_key}`, diff --git a/front/pages/api/w/[wId]/providers/[pId]/models.ts b/front/pages/api/w/[wId]/providers/[pId]/models.ts index bacaaf6f25ba..45a664dac994 100644 --- a/front/pages/api/w/[wId]/providers/[pId]/models.ts +++ b/front/pages/api/w/[wId]/providers/[pId]/models.ts @@ -200,6 +200,27 @@ async function handler( res.status(200).json({ models: anthropic_models }); return; + case "textsynth": + if (chat) { + res.status(200).json({ + models: [ + // { id: "mistral_7B_instruct" }, + // { id: "falcon_40B-chat" }, + ], + }); + return; + } + res.status(200).json({ + models: [ + { id: "mistral_7B" }, + { id: "mistral_7B_instruct" }, + { id: "falcon_7B" }, + { id: "falcon_40B" }, + { id: "llama2_7B" }, + ], + }); + return; + default: res.status(404).json({ error: "Provider not found" }); return; diff --git a/front/pages/w/[wId]/a/index.tsx b/front/pages/w/[wId]/a/index.tsx index 202ab3dc18b6..4f61da4f4ce0 100644 --- a/front/pages/w/[wId]/a/index.tsx +++ b/front/pages/w/[wId]/a/index.tsx @@ -20,6 +20,7 @@ import CohereSetup from "@app/components/providers/CohereSetup"; import OpenAISetup from "@app/components/providers/OpenAISetup"; import SerpAPISetup from "@app/components/providers/SerpAPISetup"; import SerperSetup from "@app/components/providers/SerperSetup"; +import TextSynthSetup from "@app/components/providers/TextSynthSetup"; import AppLayout from "@app/components/sparkle/AppLayout"; import { subNavigationAdmin } from 
"@app/components/sparkle/navigation"; import { getApps } from "@app/lib/api/app"; @@ -203,6 +204,7 @@ export function Providers({ owner }: { owner: WorkspaceType }) { const [ai21Open, setAI21Open] = useState(false); const [azureOpenAIOpen, setAzureOpenAIOpen] = useState(false); const [anthropicOpen, setAnthropicOpen] = useState(false); + const [textSynthOpen, setTextSynthOpen] = useState(false); const [serpapiOpen, setSerpapiOpen] = useState(false); const [serperOpen, setSerperOpen] = useState(false); const [browserlessapiOpen, setBrowserlessapiOpen] = useState(false); @@ -255,6 +257,13 @@ export function Providers({ owner }: { owner: WorkspaceType }) { enabled={configs["anthropic"] ? true : false} config={configs["anthropic"] ? configs["anthropic"] : null} /> +