From 55c1ed3b1ec58e16fadb84c80a1c76968e2dc0ed Mon Sep 17 00:00:00 2001 From: Henry Fontanier Date: Mon, 9 Dec 2024 11:04:11 +0100 Subject: [PATCH] tmp --- core/src/lib.rs | 1 + core/src/providers/openai.rs | 6 + core/src/providers/provider.rs | 6 + core/src/providers/togetherai.rs | 353 ++++++++++++++++++ .../components/providers/TogetherAISetup.tsx | 199 ++++++++++ front/lib/providers.ts | 9 + .../api/w/[wId]/providers/[pId]/check.ts | 16 + .../api/w/[wId]/providers/[pId]/models.ts | 9 + front/pages/w/[wId]/developers/providers.tsx | 12 + types/src/front/lib/api/credentials.ts | 3 + types/src/front/provider.ts | 1 + 11 files changed, 615 insertions(+) create mode 100644 core/src/providers/togetherai.rs create mode 100644 front/components/providers/TogetherAISetup.tsx diff --git a/core/src/lib.rs b/core/src/lib.rs index 6a1b248d866fb..f5d867d6a3ee2 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -51,6 +51,7 @@ pub mod providers { } pub mod anthropic; pub mod google_ai_studio; + pub mod togetherai; } pub mod http { pub mod request; diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs index 107e26d3f8fe9..f2bf3f3ce017f 100644 --- a/core/src/providers/openai.rs +++ b/core/src/providers/openai.rs @@ -1482,6 +1482,9 @@ pub async fn chat_completion( let req = req.json(&body); + println!("\n\nREQBODY: {:?}\n\n", body); + println!("\n\nURI: {:?}\n\n", uri); + let res = match timeout(Duration::new(180, 0), req.send()).await { Ok(Ok(res)) => res, Ok(Err(e)) => Err(e)?, @@ -1489,6 +1492,7 @@ pub async fn chat_completion( }; let res_headers = res.headers(); + println!("\n\nRESHEADERS: {:?}\n\n", res_headers); let request_id = match res_headers.get("x-request-id") { Some(request_id) => Some(request_id.to_str()?.to_string()), None => None, @@ -1500,6 +1504,8 @@ pub async fn chat_completion( Err(_) => Err(anyhow!("Timeout reading response from OpenAI after 180s"))?, }; + println!("RESBODY: {:?}", body); + let mut b: Vec = vec![]; body.reader().read_to_end(&mut b)?; let c: &[u8] = &b; diff --git a/core/src/providers/provider.rs b/core/src/providers/provider.rs index 9bd2e88dfcdc0..770810883605b 100644 --- a/core/src/providers/provider.rs +++ b/core/src/providers/provider.rs @@ -15,6 +15,8 @@ use std::fmt; use std::str::FromStr; use std::time::Duration; +use super::togetherai::TogetherAIProvider; + #[derive(Debug, Clone, Copy, Serialize, PartialEq, ValueEnum, Deserialize)] #[serde(rename_all = "lowercase")] #[clap(rename_all = "lowercase")] @@ -26,6 +28,7 @@ pub enum ProviderID { Mistral, #[serde(rename = "google_ai_studio")] GoogleAiStudio, + TogetherAI, } impl fmt::Display for ProviderID { @@ -36,6 +39,7 @@ impl fmt::Display for ProviderID { ProviderID::Anthropic => write!(f, "anthropic"), ProviderID::Mistral => write!(f, "mistral"), ProviderID::GoogleAiStudio => write!(f, "google_ai_studio"), + ProviderID::TogetherAI => write!(f, "togetherai"), } } } @@ -49,6 +53,7 @@ impl FromStr for ProviderID { "anthropic" => Ok(ProviderID::Anthropic), "mistral" => Ok(ProviderID::Mistral), "google_ai_studio" => Ok(ProviderID::GoogleAiStudio), + "togetherai" => Ok(ProviderID::TogetherAI), _ => Err(ParseError::with_message( "Unknown provider ID \ (possible values: openai, azure_openai, anthropic, mistral, google_ai_studio)", @@ -151,5 +156,6 @@ pub fn provider(t: ProviderID) -> Box { ProviderID::GoogleAiStudio => Box::new(GoogleAiStudioProvider::new()), ProviderID::Mistral => Box::new(MistralProvider::new()), ProviderID::OpenAI => Box::new(OpenAIProvider::new()), + ProviderID::TogetherAI => 
Box::new(TogetherAIProvider::new()),
     }
 }
diff --git a/core/src/providers/togetherai.rs b/core/src/providers/togetherai.rs
new file mode 100644
index 0000000000000..ab0daa7bf85b3
--- /dev/null
+++ b/core/src/providers/togetherai.rs
@@ -0,0 +1,353 @@
+use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage};
+use crate::providers::embedder::Embedder;
+use crate::providers::llm::ChatFunction;
+use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM};
+use crate::providers::openai::{
+    chat_completion, streamed_chat_completion, to_openai_messages, OpenAITool, OpenAIToolChoice,
+};
+use crate::providers::provider::{Provider, ProviderID};
+use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, o200k_base_singleton, CoreBPE};
+use crate::providers::tiktoken::tiktoken::{decode_async, encode_async};
+use crate::run::Credentials;
+use crate::utils;
+
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use hyper::StatusCode;
+use hyper::Uri;
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::mpsc::UnboundedSender;
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct InnerError {
+    pub message: String,
+    #[serde(alias = "type")]
+    pub _type: String,
+    pub param: Option<String>,
+    pub internal_message: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct TogetherAIError {
+    pub error: InnerError,
+}
+
+impl TogetherAIError {
+    pub fn message(&self) -> String {
+        match self.error.internal_message {
+            Some(ref msg) => format!(
+                "TogetherAIError: [{}] {} internal_message={}",
+                self.error._type, self.error.message, msg,
+            ),
+            None => format!(
+                "TogetherAIError: [{}] {}",
+                self.error._type, self.error.message,
+            ),
+        }
+    }
+
+    pub fn retryable(&self) -> bool {
+        match self.error._type.as_str() {
+            "requests" => true,
+            "server_error" => match &self.error.internal_message {
+                Some(message) if message.contains("retry") => true,
+                _ => false,
+            },
+            _ => false,
+        }
+    }
+
+    pub fn retryable_streamed(&self, status: StatusCode) -> bool {
+        if status == StatusCode::TOO_MANY_REQUESTS {
+            return true;
+        }
+        if status.is_server_error() {
+            return true;
+        }
+        match self.error._type.as_str() {
+            "server_error" => match self.error.internal_message {
+                Some(_) => true,
+                None => false,
+            },
+            _ => false,
+        }
+    }
+}
+
+pub struct TogetherAILLM {
+    id: String,
+    api_key: Option<String>,
+}
+
+impl TogetherAILLM {
+    pub fn new(id: String) -> Self {
+        TogetherAILLM { id, api_key: None }
+    }
+
+    fn chat_uri(&self) -> Result<Uri> {
+        Ok("https://api.together.xyz/v1/chat/completions".parse::<Uri>()?)
+    }
+
+    fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
+        // TODO(@fontanierh): TBD
+        o200k_base_singleton()
+    }
+
+    pub fn togetherai_context_size(_model_id: &str) -> usize {
+        // TODO(@fontanierh): TBD
+        131072
+    }
+}
+
+#[async_trait]
+impl LLM for TogetherAILLM {
+    fn id(&self) -> String {
+        self.id.clone()
+    }
+
+    async fn initialize(&mut self, credentials: Credentials) -> Result<()> {
+        match credentials.get("TOGETHERAI_API_KEY") {
+            Some(api_key) => {
+                self.api_key = Some(api_key.clone());
+            }
+            None => {
+                match tokio::task::spawn_blocking(|| std::env::var("TOGETHERAI_API_KEY")).await? {
+                    Ok(key) => {
+                        self.api_key = Some(key);
+                    }
+                    Err(_) => Err(anyhow!(
+                        "Credentials or environment variable `TOGETHERAI_API_KEY` is not set."
+                    ))?,
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn context_size(&self) -> usize {
+        Self::togetherai_context_size(self.id.as_str())
+    }
+
+    async fn encode(&self, text: &str) -> Result<Vec<usize>> {
+        encode_async(self.tokenizer(), text).await
+    }
+
+    async fn decode(&self, tokens: Vec<usize>) -> Result<String> {
+        decode_async(self.tokenizer(), tokens).await
+    }
+
+    async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
+        batch_tokenize_async(self.tokenizer(), texts).await
+    }
+
+    async fn generate(
+        &self,
+        _prompt: &str,
+        mut _max_tokens: Option<i32>,
+        _temperature: f32,
+        _n: usize,
+        _stop: &Vec<String>,
+        _frequency_penalty: Option<f32>,
+        _presence_penalty: Option<f32>,
+        _top_p: Option<f32>,
+        _top_logprobs: Option<i32>,
+        _extras: Option<Value>,
+        _event_sender: Option<UnboundedSender<Value>>,
+    ) -> Result<LLMGeneration> {
+        Err(anyhow!("Not implemented."))
+    }
+
+    // API is openai-compatible.
+    async fn chat(
+        &self,
+        messages: &Vec<ChatMessage>,
+        functions: &Vec<ChatFunction>,
+        function_call: Option<String>,
+        temperature: f32,
+        top_p: Option<f32>,
+        n: usize,
+        stop: &Vec<String>,
+        mut max_tokens: Option<i32>,
+        presence_penalty: Option<f32>,
+        frequency_penalty: Option<f32>,
+        _extras: Option<Value>,
+        event_sender: Option<UnboundedSender<Value>>,
+    ) -> Result<LLMChatGeneration> {
+        if let Some(m) = max_tokens {
+            if m == -1 {
+                max_tokens = None;
+            }
+        }
+
+        let tool_choice = match function_call.as_ref() {
+            Some(fc) => Some(OpenAIToolChoice::from_str(fc)?),
+            None => None,
+        };
+
+        let tools = functions
+            .iter()
+            .map(OpenAITool::try_from)
+            .collect::<Result<Vec<OpenAITool>, _>>()?;
+
+        let openai_messages = to_openai_messages(messages, &self.id)?;
+
+        let is_streaming = event_sender.is_some();
+
+        let (c, request_id) = if is_streaming {
+            streamed_chat_completion(
+                self.chat_uri()?,
+                self.api_key.clone().unwrap(),
+                None,
+                Some(self.id.clone()),
+                &openai_messages,
+                tools,
+                tool_choice,
+                temperature,
+                match top_p {
+                    Some(t) => t,
+                    None => 1.0,
+                },
+                n,
+                stop,
+                max_tokens,
+                match presence_penalty {
+                    Some(p) => p,
+                    None => 0.0,
+                },
+                match frequency_penalty {
+                    Some(f) => f,
+                    None => 0.0,
+                },
+                None,
+                None,
+                event_sender.clone(),
+            )
+            .await?
+        } else {
+            match chat_completion(
+                self.chat_uri()?,
+                self.api_key.clone().unwrap(),
+                None,
+                Some(self.id.clone()),
+                &openai_messages,
+                tools,
+                tool_choice,
+                temperature,
+                match top_p {
+                    Some(t) => t,
+                    None => 1.0,
+                },
+                n,
+                stop,
+                max_tokens,
+                match presence_penalty {
+                    Some(p) => p,
+                    None => 0.0,
+                },
+                match frequency_penalty {
+                    Some(f) => f,
+                    None => 0.0,
+                },
+                None,
+                None,
+            )
+            .await
+            {
+                Ok(c) => c,
+                Err(e) => {
+                    println!("\n\nTOGETHERAI ERROR: {:?}\n\n", e);
+                    return Err(e);
+                }
+            }
+        };
+
+        assert!(c.choices.len() > 0);
+
+        Ok(LLMChatGeneration {
+            created: utils::now(),
+            provider: ProviderID::TogetherAI.to_string(),
+            model: self.id.clone(),
+            completions: c
+                .choices
+                .iter()
+                .map(|c| AssistantChatMessage::try_from(&c.message))
+                .collect::<Result<Vec<_>>>()?,
+            usage: c.usage.map(|usage| LLMTokenUsage {
+                prompt_tokens: usage.prompt_tokens,
+                completion_tokens: usage.completion_tokens.unwrap_or(0),
+            }),
+            provider_request_id: request_id,
+        })
+    }
+}
+
+pub struct TogetherAIProvider {}
+
+impl TogetherAIProvider {
+    pub fn new() -> Self {
+        TogetherAIProvider {}
+    }
+}
+
+#[async_trait]
+impl Provider for TogetherAIProvider {
+    fn id(&self) -> ProviderID {
+        ProviderID::TogetherAI
+    }
+
+    fn setup(&self) -> Result<()> {
+        utils::info("Setting up TogetherAI:");
+        utils::info("");
+        utils::info(
+            "To use TogetherAI's models, you must set the environment variable `TOGETHERAI_API_KEY`.",
+        );
+        utils::info("Your API key can be found in your TogetherAI account settings.");
+        utils::info("");
+        utils::info("Once ready you can check your setup with `dust provider test togetherai`");
+
+        Ok(())
+    }
+
+    async fn test(&self) -> Result<()> {
+        if !utils::confirm(
+            "You are about to make a request for 1 token to `meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo` on the TogetherAI API.",
+        )? {
+            Err(anyhow!("User aborted TogetherAI test."))?;
+        }
+
+        let mut llm = self.llm(String::from("meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"));
+        llm.initialize(Credentials::new()).await?;
+
+        let _ = llm
+            .generate(
+                "Hello 😊",
+                Some(1),
+                0.7,
+                1,
+                &vec![],
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+            )
+            .await?;
+
+        utils::done("Test successfully completed! TogetherAI is ready to use.");
+
+        Ok(())
+    }
+
+    fn llm(&self, id: String) -> Box<dyn LLM + Sync + Send> {
+        Box::new(TogetherAILLM::new(id))
+    }
+
+    fn embedder(&self, _id: String) -> Box<dyn Embedder + Sync + Send> {
+        unimplemented!()
+    }
+}
diff --git a/front/components/providers/TogetherAISetup.tsx b/front/components/providers/TogetherAISetup.tsx
new file mode 100644
index 0000000000000..7f71e56d31e20
--- /dev/null
+++ b/front/components/providers/TogetherAISetup.tsx
@@ -0,0 +1,199 @@
+import { Button } from "@dust-tt/sparkle";
+import type { WorkspaceType } from "@dust-tt/types";
+import { Dialog, Transition } from "@headlessui/react";
+import { Fragment, useEffect, useState } from "react";
+import { useSWRConfig } from "swr";
+
+import { checkProvider } from "@app/lib/providers";
+
+export default function TogetherAISetup({
+  owner,
+  open,
+  setOpen,
+  config,
+  enabled,
+}: {
+  owner: WorkspaceType;
+  open: boolean;
+  setOpen: (open: boolean) => void;
+  config: { [key: string]: string };
+  enabled: boolean;
+}) {
+  const { mutate } = useSWRConfig();
+
+  const [apiKey, setApiKey] = useState(config ?
config.api_key : ""); + const [testSuccessful, setTestSuccessful] = useState(false); + const [testRunning, setTestRunning] = useState(false); + const [testError, setTestError] = useState(""); + const [enableRunning, setEnableRunning] = useState(false); + + useEffect(() => { + if (config && config.api_key.length > 0 && apiKey.length == 0) { + setApiKey(config.api_key); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [config]); + + const runTest = async () => { + setTestRunning(true); + setTestError(""); + const check = await checkProvider(owner, "togetherai", { + api_key: apiKey, + }); + + if (!check.ok) { + setTestError(check.error); + setTestSuccessful(false); + setTestRunning(false); + } else { + setTestError(""); + setTestSuccessful(true); + setTestRunning(false); + } + }; + + const handleEnable = async () => { + setEnableRunning(true); + const res = await fetch(`/api/w/${owner.sId}/providers/togetherai`, { + headers: { + "Content-Type": "application/json", + }, + method: "POST", + body: JSON.stringify({ + config: JSON.stringify({ + api_key: apiKey, + }), + }), + }); + await res.json(); + setEnableRunning(false); + setOpen(false); + await mutate(`/api/w/${owner.sId}/providers`); + }; + + const handleDisable = async () => { + const res = await fetch(`/api/w/${owner.sId}/providers/togetherai`, { + method: "DELETE", + }); + await res.json(); + setOpen(false); + await mutate(`/api/w/${owner.sId}/providers`); + }; + + return ( + + setOpen(false)}> + +
+ + +
+
+ + +
+
+ + Setup TogetherAI + +
+

+ To use TogetherAI models you must provide your API key. +

+

+ We'll never use your API key for anything other than to + run your apps. +

+
+
+ { + setApiKey(e.target.value); + setTestSuccessful(false); + }} + /> +
+
+
+
+ {testError.length > 0 ? ( + Error: {testError} + ) : testSuccessful ? ( + + Test succeeded! You can enable TogetherAI. + + ) : ( +   + )} +
+
+ {enabled ? ( +
handleDisable()} + > + Disable +
+ ) : ( + <> + )} +
+
+
+
+ {testSuccessful ? ( +
+
+
+
+
+
+
+
+ ); +} diff --git a/front/lib/providers.ts b/front/lib/providers.ts index 2e16c0e0bb67b..c22c2fbc86085 100644 --- a/front/lib/providers.ts +++ b/front/lib/providers.ts @@ -54,6 +54,14 @@ export const modelProviders: ModelProvider[] = [ chat: true, embed: false, }, + { + providerId: "togetherai", + name: "TogetherAI", + built: true, + enabled: false, + chat: true, + embed: false, + }, ]; export const APP_MODEL_PROVIDER_IDS: string[] = [ @@ -61,6 +69,7 @@ export const APP_MODEL_PROVIDER_IDS: string[] = [ "anthropic", "mistral", "google_ai_studio", + "togetherai", "azure_openai", ] as const; diff --git a/front/pages/api/w/[wId]/providers/[pId]/check.ts b/front/pages/api/w/[wId]/providers/[pId]/check.ts index c867cca0aba4b..73b6d6696b267 100644 --- a/front/pages/api/w/[wId]/providers/[pId]/check.ts +++ b/front/pages/api/w/[wId]/providers/[pId]/check.ts @@ -242,6 +242,22 @@ async function handler( await rGoogleAIStudio.json(); return res.status(200).json({ ok: true }); + case "togetherai": + const tModelsRes = await fetch("https://api.together.xyz/v1/models", { + method: "GET", + headers: { + Authorization: `Bearer ${config.api_key}`, + }, + }); + if (!tModelsRes.ok) { + const err = await tModelsRes.json(); + res.status(400).json({ ok: false, error: err.error.code }); + } else { + await tModelsRes.json(); + res.status(200).json({ ok: true }); + } + return; + default: return apiError(req, res, { status_code: 404, diff --git a/front/pages/api/w/[wId]/providers/[pId]/models.ts b/front/pages/api/w/[wId]/providers/[pId]/models.ts index fcaabd4809e70..5a9df8cb21cc7 100644 --- a/front/pages/api/w/[wId]/providers/[pId]/models.ts +++ b/front/pages/api/w/[wId]/providers/[pId]/models.ts @@ -242,6 +242,15 @@ async function handler( ], }); + case "togetherai": + if (embed) { + res.status(200).json({ models: [] }); + return; + } + return res.status(200).json({ + models: [{ id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo" }], + }); + default: return apiError(req, res, { status_code: 404, diff --git a/front/pages/w/[wId]/developers/providers.tsx b/front/pages/w/[wId]/developers/providers.tsx index 80f6ad8f9fe1f..f686392233352 100644 --- a/front/pages/w/[wId]/developers/providers.tsx +++ b/front/pages/w/[wId]/developers/providers.tsx @@ -14,6 +14,7 @@ import MistralAISetup from "@app/components/providers/MistralAISetup"; import OpenAISetup from "@app/components/providers/OpenAISetup"; import SerpAPISetup from "@app/components/providers/SerpAPISetup"; import SerperSetup from "@app/components/providers/SerperSetup"; +import TogetherAISetup from "@app/components/providers/TogetherAISetup"; import AppLayout from "@app/components/sparkle/AppLayout"; import { withDefaultUserAuthRequirements } from "@app/lib/iam/session"; import { @@ -56,6 +57,7 @@ export function Providers({ owner }: { owner: WorkspaceType }) { const [serpapiOpen, setSerpapiOpen] = useState(false); const [serperOpen, setSerperOpen] = useState(false); const [browserlessapiOpen, setBrowserlessapiOpen] = useState(false); + const [togetherAiOpen, setTogetherAiOpen] = useState(false); const { providers, isProvidersLoading, isProvidersError } = useProviders({ owner, @@ -130,6 +132,13 @@ export function Providers({ owner }: { owner: WorkspaceType }) { enabled={!!configs["google_ai_studio"]} config={configs["google_ai_studio"] ?? null} /> +
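As a quick way to sanity-check the endpoints this patch relies on, here is a minimal TypeScript sketch (an illustration, not code from the patch) that mirrors what the `check.ts` handler and the Rust provider `test()` routine do: it validates the key by listing models, then requests a single token from the OpenAI-compatible chat completions endpoint. It assumes Node 18+ (for the global `fetch`) and a `TOGETHERAI_API_KEY` environment variable.

```ts
// Illustrative sketch, not part of the patch.
// Assumes Node 18+ (global fetch) and TOGETHERAI_API_KEY in the environment.
const apiKey = process.env.TOGETHERAI_API_KEY;
if (!apiKey) {
  throw new Error("TOGETHERAI_API_KEY is not set.");
}

async function main() {
  // Same check as front/pages/api/w/[wId]/providers/[pId]/check.ts: list models.
  const modelsRes = await fetch("https://api.together.xyz/v1/models", {
    method: "GET",
    headers: { Authorization: `Bearer ${apiKey}` },
  });
  if (!modelsRes.ok) {
    throw new Error(`Key check failed with status ${modelsRes.status}`);
  }

  // Same endpoint the Rust provider targets (OpenAI-compatible), limited to
  // a single token like the `dust provider test togetherai` flow.
  const chatRes = await fetch("https://api.together.xyz/v1/chat/completions", {
    method: "POST",
    headers: {
      Authorization: `Bearer ${apiKey}`,
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      model: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
      messages: [{ role: "user", content: "Hello 😊" }],
      max_tokens: 1,
    }),
  });
  console.log(JSON.stringify(await chatRes.json(), null, 2));
}

void main();
```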