diff --git a/core/src/lib.rs b/core/src/lib.rs
index 6a1b248d866f..f5d867d6a3ee 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -51,6 +51,7 @@ pub mod providers {
     }
     pub mod anthropic;
     pub mod google_ai_studio;
+    pub mod togetherai;
 }
 pub mod http {
     pub mod request;
diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs
index 107e26d3f8fe..6444e0136b68 100644
--- a/core/src/providers/openai.rs
+++ b/core/src/providers/openai.rs
@@ -287,12 +287,16 @@ pub enum OpenAIContentBlock {
 }
 
 #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
-pub struct OpenAIContentBlockVec(Vec<OpenAIContentBlock>);
+#[serde(untagged)]
+pub enum OpenAIChatMessageContent {
+    Structured(Vec<OpenAIContentBlock>),
+    String(String),
+}
 
 #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
 pub struct OpenAIChatMessage {
     pub role: OpenAIChatMessageRole,
-    pub content: Option<OpenAIContentBlockVec>,
+    pub content: Option<OpenAIChatMessageContent>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub name: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -379,12 +383,12 @@ impl TryFrom<&OpenAICompletionChatMessage> for AssistantChatMessage {
     }
 }
 
-impl TryFrom<&ContentBlock> for OpenAIContentBlockVec {
+impl TryFrom<&ContentBlock> for OpenAIChatMessageContent {
     type Error = anyhow::Error;
 
     fn try_from(cm: &ContentBlock) -> Result<Self, Self::Error> {
         match cm {
-            ContentBlock::Text(t) => Ok(OpenAIContentBlockVec(vec![
+            ContentBlock::Text(t) => Ok(OpenAIChatMessageContent::Structured(vec![
                 OpenAIContentBlock::TextContent(OpenAITextContent {
                     r#type: OpenAITextContentType::Text,
                     text: t.clone(),
@@ -411,17 +415,17 @@ impl TryFrom<&ContentBlock> for OpenAIContentBlockVec {
                     })
                     .collect::<Result<Vec<OpenAIContentBlock>>>()?;
 
-                Ok(OpenAIContentBlockVec(content))
+                Ok(OpenAIChatMessageContent::Structured(content))
             }
         }
     }
 }
 
-impl TryFrom<&String> for OpenAIContentBlockVec {
+impl TryFrom<&String> for OpenAIChatMessageContent {
     type Error = anyhow::Error;
 
     fn try_from(t: &String) -> Result<Self, Self::Error> {
-        Ok(OpenAIContentBlockVec(vec![
+        Ok(OpenAIChatMessageContent::Structured(vec![
             OpenAIContentBlock::TextContent(OpenAITextContent {
                 r#type: OpenAITextContentType::Text,
                 text: t.clone(),
@@ -437,7 +441,7 @@ impl TryFrom<&ChatMessage> for OpenAIChatMessage {
         match cm {
             ChatMessage::Assistant(assistant_msg) => Ok(OpenAIChatMessage {
                 content: match &assistant_msg.content {
-                    Some(c) => Some(OpenAIContentBlockVec::try_from(c)?),
+                    Some(c) => Some(OpenAIChatMessageContent::try_from(c)?),
                     None => None,
                 },
                 name: assistant_msg.name.clone(),
@@ -453,21 +457,21 @@ impl TryFrom<&ChatMessage> for OpenAIChatMessage {
                 tool_call_id: None,
             }),
             ChatMessage::Function(function_msg) => Ok(OpenAIChatMessage {
-                content: Some(OpenAIContentBlockVec::try_from(&function_msg.content)?),
+                content: Some(OpenAIChatMessageContent::try_from(&function_msg.content)?),
                 name: None,
                 role: OpenAIChatMessageRole::Tool,
                 tool_calls: None,
                 tool_call_id: Some(function_msg.function_call_id.clone()),
             }),
             ChatMessage::System(system_msg) => Ok(OpenAIChatMessage {
-                content: Some(OpenAIContentBlockVec::try_from(&system_msg.content)?),
+                content: Some(OpenAIChatMessageContent::try_from(&system_msg.content)?),
                 name: None,
                 role: OpenAIChatMessageRole::from(&system_msg.role),
                 tool_calls: None,
                 tool_call_id: None,
             }),
             ChatMessage::User(user_msg) => Ok(OpenAIChatMessage {
-                content: Some(OpenAIContentBlockVec::try_from(&user_msg.content)?),
+                content: Some(OpenAIChatMessageContent::try_from(&user_msg.content)?),
                 name: user_msg.name.clone(),
                 role: OpenAIChatMessageRole::from(&user_msg.role),
                 tool_calls: None,
diff --git a/core/src/providers/provider.rs b/core/src/providers/provider.rs
index 9bd2e88dfcdc..770810883605 100644
--- a/core/src/providers/provider.rs
+++ b/core/src/providers/provider.rs
@@ -15,6 +15,8 @@ use std::fmt;
 use std::str::FromStr;
 use std::time::Duration;
 
+use super::togetherai::TogetherAIProvider;
+
 #[derive(Debug, Clone, Copy, Serialize, PartialEq, ValueEnum, Deserialize)]
 #[serde(rename_all = "lowercase")]
 #[clap(rename_all = "lowercase")]
@@ -26,6 +28,7 @@ pub enum ProviderID {
     Mistral,
     #[serde(rename = "google_ai_studio")]
     GoogleAiStudio,
+    TogetherAI,
 }
 
 impl fmt::Display for ProviderID {
@@ -36,6 +39,7 @@ impl fmt::Display for ProviderID {
             ProviderID::Anthropic => write!(f, "anthropic"),
             ProviderID::Mistral => write!(f, "mistral"),
             ProviderID::GoogleAiStudio => write!(f, "google_ai_studio"),
+            ProviderID::TogetherAI => write!(f, "togetherai"),
         }
     }
 }
@@ -49,6 +53,7 @@ impl FromStr for ProviderID {
             "anthropic" => Ok(ProviderID::Anthropic),
             "mistral" => Ok(ProviderID::Mistral),
             "google_ai_studio" => Ok(ProviderID::GoogleAiStudio),
+            "togetherai" => Ok(ProviderID::TogetherAI),
             _ => Err(ParseError::with_message(
                 "Unknown provider ID \
                  (possible values: openai, azure_openai, anthropic, mistral, google_ai_studio)",
@@ -151,5 +156,6 @@ pub fn provider(t: ProviderID) -> Box<dyn Provider + Sync + Send> {
         ProviderID::GoogleAiStudio => Box::new(GoogleAiStudioProvider::new()),
         ProviderID::Mistral => Box::new(MistralProvider::new()),
         ProviderID::OpenAI => Box::new(OpenAIProvider::new()),
+        ProviderID::TogetherAI => Box::new(TogetherAIProvider::new()),
     }
 }
diff --git a/core/src/providers/togetherai.rs b/core/src/providers/togetherai.rs
new file mode 100644
index 000000000000..806f5b1bed34
--- /dev/null
+++ b/core/src/providers/togetherai.rs
@@ -0,0 +1,317 @@
+use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage};
+use crate::providers::embedder::Embedder;
+use crate::providers::llm::ChatFunction;
+use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM};
+use crate::providers::openai::{
+    chat_completion, streamed_chat_completion, to_openai_messages, OpenAIChatMessage,
+    OpenAIChatMessageContent, OpenAIContentBlock, OpenAITextContent, OpenAITextContentType,
+    OpenAITool, OpenAIToolChoice,
+};
+use crate::providers::provider::{Provider, ProviderID};
+use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, o200k_base_singleton, CoreBPE};
+use crate::providers::tiktoken::tiktoken::{decode_async, encode_async};
+use crate::run::Credentials;
+use crate::utils;
+
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use hyper::Uri;
+use parking_lot::RwLock;
+use serde_json::Value;
+use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::mpsc::UnboundedSender;
+
+pub struct TogetherAILLM {
+    id: String,
+    api_key: Option<String>,
+}
+
+impl TogetherAILLM {
+    pub fn new(id: String) -> Self {
+        TogetherAILLM { id, api_key: None }
+    }
+
+    fn chat_uri(&self) -> Result<Uri> {
+        Ok(format!("https://api.together.xyz/v1/chat/completions").parse::<Uri>()?)
+    }
+
+    fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
+        // TODO(@fontanierh): TBD
+        o200k_base_singleton()
+    }
+
+    pub fn togetherai_context_size(_model_id: &str) -> usize {
+        // TODO(@fontanierh): TBD
+        131072
+    }
+}
+
+#[async_trait]
+impl LLM for TogetherAILLM {
+    fn id(&self) -> String {
+        self.id.clone()
+    }
+
+    async fn initialize(&mut self, credentials: Credentials) -> Result<()> {
+        match credentials.get("TOGETHERAI_API_KEY") {
+            Some(api_key) => {
+                self.api_key = Some(api_key.clone());
+            }
+            None => {
+                match tokio::task::spawn_blocking(|| std::env::var("TOGETHERAI_API_KEY")).await?
+                {
+                    Ok(key) => {
+                        self.api_key = Some(key);
+                    }
+                    Err(_) => Err(anyhow!(
+                        "Credentials or environment variable `TOGETHERAI_API_KEY` is not set."
+                    ))?,
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn context_size(&self) -> usize {
+        Self::togetherai_context_size(self.id.as_str())
+    }
+
+    async fn encode(&self, text: &str) -> Result<Vec<usize>> {
+        encode_async(self.tokenizer(), text).await
+    }
+
+    async fn decode(&self, tokens: Vec<usize>) -> Result<String> {
+        decode_async(self.tokenizer(), tokens).await
+    }
+
+    async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
+        batch_tokenize_async(self.tokenizer(), texts).await
+    }
+
+    async fn generate(
+        &self,
+        _prompt: &str,
+        mut _max_tokens: Option<i32>,
+        _temperature: f32,
+        _n: usize,
+        _stop: &Vec<String>,
+        _frequency_penalty: Option<f32>,
+        _presence_penalty: Option<f32>,
+        _top_p: Option<f32>,
+        _top_logprobs: Option<i32>,
+        _extras: Option<Value>,
+        _event_sender: Option<UnboundedSender<Value>>,
+    ) -> Result<LLMGeneration> {
+        Err(anyhow!("Not implemented."))
+    }
+
+    // API is openai-compatible.
+    async fn chat(
+        &self,
+        messages: &Vec<ChatMessage>,
+        functions: &Vec<ChatFunction>,
+        function_call: Option<String>,
+        temperature: f32,
+        top_p: Option<f32>,
+        n: usize,
+        stop: &Vec<String>,
+        mut max_tokens: Option<i32>,
+        presence_penalty: Option<f32>,
+        frequency_penalty: Option<f32>,
+        _extras: Option<Value>,
+        event_sender: Option<UnboundedSender<Value>>,
+    ) -> Result<LLMChatGeneration> {
+        if let Some(m) = max_tokens {
+            if m == -1 {
+                max_tokens = None;
+            }
+        }
+
+        let tool_choice = match function_call.as_ref() {
+            Some(fc) => Some(OpenAIToolChoice::from_str(fc)?),
+            None => None,
+        };
+
+        let tools = functions
+            .iter()
+            .map(OpenAITool::try_from)
+            .collect::<Result<Vec<OpenAITool>, _>>()?;
+
+        // TogetherAI doesn't work with the new chat message content format.
+        // We have to modify the messages contents to use the "String" format.
+        let openai_messages = to_openai_messages(messages, &self.id)?
+            .into_iter()
+            .filter_map(|m| match m.content {
+                None => Some(m),
+                Some(OpenAIChatMessageContent::String(_)) => Some(m),
+                Some(OpenAIChatMessageContent::Structured(contents)) => {
+                    // Find the first text content, and use it to make a string content.
+                    let content = contents.into_iter().find_map(|c| match c {
+                        OpenAIContentBlock::TextContent(OpenAITextContent {
+                            r#type: OpenAITextContentType::Text,
+                            text,
+                            ..
+                        }) => Some(OpenAIChatMessageContent::String(text)),
+                        _ => None,
+                    });
+
+                    Some(OpenAIChatMessage {
+                        role: m.role,
+                        name: m.name,
+                        tool_call_id: m.tool_call_id,
+                        tool_calls: m.tool_calls,
+                        content,
+                    })
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let is_streaming = event_sender.is_some();
+
+        let (c, request_id) = if is_streaming {
+            streamed_chat_completion(
+                self.chat_uri()?,
+                self.api_key.clone().unwrap(),
+                None,
+                Some(self.id.clone()),
+                &openai_messages,
+                tools,
+                tool_choice,
+                temperature,
+                match top_p {
+                    Some(t) => t,
+                    None => 1.0,
+                },
+                n,
+                stop,
+                max_tokens,
+                match presence_penalty {
+                    Some(p) => p,
+                    None => 0.0,
+                },
+                match frequency_penalty {
+                    Some(f) => f,
+                    None => 0.0,
+                },
+                None,
+                None,
+                event_sender.clone(),
+            )
+            .await?
+        } else {
+            chat_completion(
+                self.chat_uri()?,
+                self.api_key.clone().unwrap(),
+                None,
+                Some(self.id.clone()),
+                &openai_messages,
+                tools,
+                tool_choice,
+                temperature,
+                match top_p {
+                    Some(t) => t,
+                    None => 1.0,
+                },
+                n,
+                stop,
+                max_tokens,
+                match presence_penalty {
+                    Some(p) => p,
+                    None => 0.0,
+                },
+                match frequency_penalty {
+                    Some(f) => f,
+                    None => 0.0,
+                },
+                None,
+                None,
+            )
+            .await?
+        };
+
+        assert!(c.choices.len() > 0);
+
+        Ok(LLMChatGeneration {
+            created: utils::now(),
+            provider: ProviderID::TogetherAI.to_string(),
+            model: self.id.clone(),
+            completions: c
+                .choices
+                .iter()
+                .map(|c| AssistantChatMessage::try_from(&c.message))
+                .collect::<Result<Vec<_>>>()?,
+            usage: c.usage.map(|usage| LLMTokenUsage {
+                prompt_tokens: usage.prompt_tokens,
+                completion_tokens: usage.completion_tokens.unwrap_or(0),
+            }),
+            provider_request_id: request_id,
+        })
+    }
+}
+
+pub struct TogetherAIProvider {}
+
+impl TogetherAIProvider {
+    pub fn new() -> Self {
+        TogetherAIProvider {}
+    }
+}
+
+#[async_trait]
+impl Provider for TogetherAIProvider {
+    fn id(&self) -> ProviderID {
+        ProviderID::TogetherAI
+    }
+
+    fn setup(&self) -> Result<()> {
+        utils::info("Setting up TogetherAI:");
+        utils::info("");
+        utils::info(
+            "To use TogetherAI's models, you must set the environment variable `TOGETHERAI_API_KEY`.",
+        );
+        utils::info(
+            "Your API key can be found in your TogetherAI account settings (`https://api.together.xyz/settings/api-keys`).",
+        );
+        utils::info("");
+        utils::info("Once ready you can check your setup with `dust provider test togetherai`");
+
+        Ok(())
+    }
+
+    async fn test(&self) -> Result<()> {
+        if !utils::confirm(
+            "You are about to make a request for 1 token to `meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo` on the TogetherAI API.",
+        )? {
+            Err(anyhow!("User aborted TogetherAI test."))?;
+        }
+
+        let mut llm = self.llm(String::from("meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"));
+        llm.initialize(Credentials::new()).await?;
+
+        let _ = llm
+            .generate(
+                "Hello 😊",
+                Some(1),
+                0.7,
+                1,
+                &vec![],
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+            )
+            .await?;
+
+        utils::done("Test successfully completed! TogetherAI is ready to use.");
+
+        Ok(())
+    }
+
+    fn llm(&self, id: String) -> Box<dyn LLM + Sync + Send> {
+        Box::new(TogetherAILLM::new(id))
+    }
+
+    fn embedder(&self, _id: String) -> Box<dyn Embedder + Sync + Send> {
+        unimplemented!()
+    }
+}
diff --git a/front/components/providers/TogetherAISetup.tsx b/front/components/providers/TogetherAISetup.tsx
new file mode 100644
index 000000000000..d25785738c08
--- /dev/null
+++ b/front/components/providers/TogetherAISetup.tsx
@@ -0,0 +1,199 @@
+import { Button } from "@dust-tt/sparkle";
+import type { WorkspaceType } from "@dust-tt/types";
+import { Dialog, Transition } from "@headlessui/react";
+import { Fragment, useEffect, useState } from "react";
+import { useSWRConfig } from "swr";
+
+import { checkProvider } from "@app/lib/providers";
+
+export default function TogetherAISetup({
+  owner,
+  open,
+  setOpen,
+  config,
+  enabled,
+}: {
+  owner: WorkspaceType;
+  open: boolean;
+  setOpen: (open: boolean) => void;
+  config: { [key: string]: string };
+  enabled: boolean;
+}) {
+  const { mutate } = useSWRConfig();
+
+  const [apiKey, setApiKey] = useState(config ? config.api_key : "");
+  const [testSuccessful, setTestSuccessful] = useState(false);
+  const [testRunning, setTestRunning] = useState(false);
+  const [testError, setTestError] = useState("");
+  const [enableRunning, setEnableRunning] = useState(false);
+
+  useEffect(() => {
+    if (config && config.api_key.length > 0 && apiKey.length == 0) {
+      setApiKey(config.api_key);
+    }
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [config]);
+
+  const runTest = async () => {
+    setTestRunning(true);
+    setTestError("");
+    const check = await checkProvider(owner, "togetherai", {
+      api_key: apiKey,
+    });
+
+    if (!check.ok) {
+      setTestError(check.error);
+      setTestSuccessful(false);
+      setTestRunning(false);
+    } else {
+      setTestError("");
+      setTestSuccessful(true);
+      setTestRunning(false);
+    }
+  };
+
+  const handleEnable = async () => {
+    setEnableRunning(true);
+    const res = await fetch(`/api/w/${owner.sId}/providers/togetherai`, {
+      headers: {
+        "Content-Type": "application/json",
+      },
+      method: "POST",
+      body: JSON.stringify({
+        config: JSON.stringify({
+          api_key: apiKey,
+        }),
+      }),
+    });
+    await res.json();
+    setEnableRunning(false);
+    setOpen(false);
+    await mutate(`/api/w/${owner.sId}/providers`);
+  };
+
+  const handleDisable = async () => {
+    const res = await fetch(`/api/w/${owner.sId}/providers/togetherai`, {
+      method: "DELETE",
+    });
+    await res.json();
+    setOpen(false);
+    await mutate(`/api/w/${owner.sId}/providers`);
+  };
+
+  return (
+    <Transition.Root show={open} as={Fragment}>
+      <Dialog as="div" onClose={() => setOpen(false)}>
+        <Transition.Child as={Fragment}>
+          <div className="fixed inset-0 bg-gray-800 bg-opacity-75 transition-opacity" />
+        </Transition.Child>
+        <div className="fixed inset-0 z-10 overflow-y-auto">
+          <div className="flex min-h-full items-center justify-center p-4">
+            <Dialog.Panel className="rounded-lg bg-white px-4 pb-4 pt-5 shadow-xl">
+              <Dialog.Title className="text-lg font-medium text-gray-900">
+                Setup TogetherAI
+              </Dialog.Title>
+              <p className="mt-2 text-sm text-gray-500">
+                To use TogetherAI models you must provide your API key.
+              </p>
+              <p className="mt-2 text-sm text-gray-500">
+                We&apos;ll never use your API key for anything other than to
+                run your apps.
+              </p>
+              <input
+                type="text"
+                placeholder="TogetherAI API Key"
+                value={apiKey}
+                onChange={(e) => {
+                  setApiKey(e.target.value);
+                  setTestSuccessful(false);
+                }}
+              />
+              <div className="mt-2 text-sm">
+                {testError?.length > 0 ? (
+                  <span className="text-red-500">Error: {testError}</span>
+                ) : testSuccessful ? (
+                  <span className="text-green-600">
+                    Test succeeded! You can enable TogetherAI.
+                  </span>
+                ) : (
+                  <span>&nbsp;</span>
+                )}
+              </div>
+              <div className="mt-5 flex flex-row items-center space-x-2">
+                {enabled ? (
+                  <Button label="Disable" onClick={() => handleDisable()} />
+                ) : (
+                  <></>
+                )}
+                <div className="flex-1" />
+                <Button label="Cancel" onClick={() => setOpen(false)} />
+                {testSuccessful ? (
+                  <Button
+                    label={enabled ? "Update" : "Enable"}
+                    disabled={enableRunning}
+                    onClick={() => handleEnable()}
+                  />
+                ) : (
+                  <Button
+                    label={testRunning ? "Testing..." : "Test"}
+                    disabled={apiKey.length === 0 || testRunning}
+                    onClick={() => runTest()}
+                  />
+                )}
+              </div>
+            </Dialog.Panel>
+          </div>
+        </div>
+      </Dialog>
+    </Transition.Root>
+  );
+}
diff --git a/front/lib/providers.ts b/front/lib/providers.ts
index 2e16c0e0bb67..c22c2fbc8608 100644
--- a/front/lib/providers.ts
+++ b/front/lib/providers.ts
@@ -54,6 +54,14 @@ export const modelProviders: ModelProvider[] = [
     chat: true,
     embed: false,
   },
+  {
+    providerId: "togetherai",
+    name: "TogetherAI",
+    built: true,
+    enabled: false,
+    chat: true,
+    embed: false,
+  },
 ];
 
 export const APP_MODEL_PROVIDER_IDS: string[] = [
@@ -61,6 +69,7 @@ export const APP_MODEL_PROVIDER_IDS: string[] = [
   "anthropic",
   "mistral",
   "google_ai_studio",
+  "togetherai",
   "azure_openai",
 ] as const;
diff --git a/front/pages/api/w/[wId]/providers/[pId]/check.ts b/front/pages/api/w/[wId]/providers/[pId]/check.ts
index c867cca0aba4..b9f9cc293d71 100644
--- a/front/pages/api/w/[wId]/providers/[pId]/check.ts
+++ b/front/pages/api/w/[wId]/providers/[pId]/check.ts
@@ -242,6 +242,22 @@ async function handler(
       await rGoogleAIStudio.json();
       return res.status(200).json({ ok: true });
 
+    case "togetherai":
+      const tModelsRes = await fetch("https://api.together.xyz/v1/models", {
+        method: "GET",
+        headers: {
+          Authorization: `Bearer ${config.api_key}`,
+        },
+      });
+      if (!tModelsRes.ok) {
+        const err = await tModelsRes.json();
+        res.status(400).json({ ok: false, error: err.error.message });
+      } else {
+        await tModelsRes.json();
+        res.status(200).json({ ok: true });
+      }
+      return;
+
     default:
       return apiError(req, res, {
         status_code: 404,
diff --git a/front/pages/api/w/[wId]/providers/[pId]/models.ts b/front/pages/api/w/[wId]/providers/[pId]/models.ts
index fcaabd4809e7..20bc0260320e 100644
--- a/front/pages/api/w/[wId]/providers/[pId]/models.ts
+++ b/front/pages/api/w/[wId]/providers/[pId]/models.ts
@@ -242,6 +242,22 @@ async function handler(
         ],
       });
 
+    case "togetherai":
+      if (embed) {
+        res.status(200).json({ models: [] });
+        return;
+      }
+      return res.status(200).json({
+        models: [
+          // llama
+          { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo" },
+          // qwen
+          { id: "Qwen/Qwen2.5-Coder-32B-Instruct" },
+          { id: "Qwen/QwQ-32B-Preview" },
+          { id: "Qwen/Qwen2-72B-Instruct" },
+        ],
+      });
+
     default:
       return apiError(req, res, {
         status_code: 404,
diff --git a/front/pages/w/[wId]/developers/providers.tsx b/front/pages/w/[wId]/developers/providers.tsx
index 80f6ad8f9fe1..f68639223335 100644
--- a/front/pages/w/[wId]/developers/providers.tsx
+++ b/front/pages/w/[wId]/developers/providers.tsx
@@ -14,6 +14,7 @@ import MistralAISetup from "@app/components/providers/MistralAISetup";
 import OpenAISetup from "@app/components/providers/OpenAISetup";
 import SerpAPISetup from "@app/components/providers/SerpAPISetup";
 import SerperSetup from "@app/components/providers/SerperSetup";
+import TogetherAISetup from "@app/components/providers/TogetherAISetup";
 import AppLayout from "@app/components/sparkle/AppLayout";
 import { withDefaultUserAuthRequirements } from "@app/lib/iam/session";
 import {
@@ -56,6 +57,7 @@ export function Providers({ owner }: { owner: WorkspaceType }) {
   const [serpapiOpen, setSerpapiOpen] = useState(false);
   const [serperOpen, setSerperOpen] = useState(false);
   const [browserlessapiOpen, setBrowserlessapiOpen] = useState(false);
+  const [togetherAiOpen, setTogetherAiOpen] = useState(false);
 
   const { providers, isProvidersLoading, isProvidersError } = useProviders({
     owner,
@@ -130,6 +132,13 @@ export function Providers({ owner }: { owner: WorkspaceType }) {
         enabled={!!configs["google_ai_studio"]}
         config={configs["google_ai_studio"] ?? null}
       />
+      <TogetherAISetup
+        owner={owner}
+        open={togetherAiOpen}
+        setOpen={setTogetherAiOpen}
+        enabled={!!configs["togetherai"]}
+        config={configs["togetherai"] ?? null}
+      />
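
For reference, a minimal sketch of how the new provider can be exercised through the `provider()` factory extended in this patch. It is illustrative only: the `dust` crate path and the model id are assumptions, and `TOGETHERAI_API_KEY` must be available via credentials or the environment.

    use dust::providers::provider::{provider, ProviderID};
    use dust::run::Credentials;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Resolve the TogetherAI provider through the ProviderID wiring added above.
        let p = provider(ProviderID::TogetherAI);

        // Model id assumed; any chat model served by TogetherAI should work.
        let mut llm = p.llm(String::from("meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"));

        // Picks up TOGETHERAI_API_KEY from credentials or the environment.
        llm.initialize(Credentials::new()).await?;

        // `chat(...)` then proxies to TogetherAI's OpenAI-compatible
        // /v1/chat/completions endpoint (see `chat_uri()` in the patch).
        Ok(())
    }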