Skip to content

Commit

Permalink
feat: add support for llama 3.3 via together AI (#9222)
Browse files Browse the repository at this point in the history
* feat: add support for llama 3.3 via together AI

* remove useless type

* add some Qwens too

---------

Co-authored-by: Henry Fontanier <[email protected]>
  • Loading branch information
fontanierh and Henry Fontanier authored Dec 11, 2024
1 parent 386fba0 commit c678d5f
Show file tree
Hide file tree
Showing 11 changed files with 595 additions and 11 deletions.
1 change: 1 addition & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub mod providers {
}
pub mod anthropic;
pub mod google_ai_studio;
pub mod togetherai;
}
pub mod http {
pub mod request;
Expand Down
26 changes: 15 additions & 11 deletions core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,16 @@ pub enum OpenAIContentBlock {
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct OpenAIContentBlockVec(Vec<OpenAIContentBlock>);
#[serde(untagged)]
pub enum OpenAIChatMessageContent {
Structured(Vec<OpenAIContentBlock>),
String(String),
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct OpenAIChatMessage {
pub role: OpenAIChatMessageRole,
pub content: Option<OpenAIContentBlockVec>,
pub content: Option<OpenAIChatMessageContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -379,12 +383,12 @@ impl TryFrom<&OpenAICompletionChatMessage> for AssistantChatMessage {
}
}

impl TryFrom<&ContentBlock> for OpenAIContentBlockVec {
impl TryFrom<&ContentBlock> for OpenAIChatMessageContent {
type Error = anyhow::Error;

fn try_from(cm: &ContentBlock) -> Result<Self, Self::Error> {
match cm {
ContentBlock::Text(t) => Ok(OpenAIContentBlockVec(vec![
ContentBlock::Text(t) => Ok(OpenAIChatMessageContent::Structured(vec![
OpenAIContentBlock::TextContent(OpenAITextContent {
r#type: OpenAITextContentType::Text,
text: t.clone(),
Expand All @@ -411,17 +415,17 @@ impl TryFrom<&ContentBlock> for OpenAIContentBlockVec {
})
.collect::<Result<Vec<OpenAIContentBlock>>>()?;

Ok(OpenAIContentBlockVec(content))
Ok(OpenAIChatMessageContent::Structured(content))
}
}
}
}

impl TryFrom<&String> for OpenAIContentBlockVec {
impl TryFrom<&String> for OpenAIChatMessageContent {
type Error = anyhow::Error;

fn try_from(t: &String) -> Result<Self, Self::Error> {
Ok(OpenAIContentBlockVec(vec![
Ok(OpenAIChatMessageContent::Structured(vec![
OpenAIContentBlock::TextContent(OpenAITextContent {
r#type: OpenAITextContentType::Text,
text: t.clone(),
Expand All @@ -437,7 +441,7 @@ impl TryFrom<&ChatMessage> for OpenAIChatMessage {
match cm {
ChatMessage::Assistant(assistant_msg) => Ok(OpenAIChatMessage {
content: match &assistant_msg.content {
Some(c) => Some(OpenAIContentBlockVec::try_from(c)?),
Some(c) => Some(OpenAIChatMessageContent::try_from(c)?),
None => None,
},
name: assistant_msg.name.clone(),
Expand All @@ -453,21 +457,21 @@ impl TryFrom<&ChatMessage> for OpenAIChatMessage {
tool_call_id: None,
}),
ChatMessage::Function(function_msg) => Ok(OpenAIChatMessage {
content: Some(OpenAIContentBlockVec::try_from(&function_msg.content)?),
content: Some(OpenAIChatMessageContent::try_from(&function_msg.content)?),
name: None,
role: OpenAIChatMessageRole::Tool,
tool_calls: None,
tool_call_id: Some(function_msg.function_call_id.clone()),
}),
ChatMessage::System(system_msg) => Ok(OpenAIChatMessage {
content: Some(OpenAIContentBlockVec::try_from(&system_msg.content)?),
content: Some(OpenAIChatMessageContent::try_from(&system_msg.content)?),
name: None,
role: OpenAIChatMessageRole::from(&system_msg.role),
tool_calls: None,
tool_call_id: None,
}),
ChatMessage::User(user_msg) => Ok(OpenAIChatMessage {
content: Some(OpenAIContentBlockVec::try_from(&user_msg.content)?),
content: Some(OpenAIChatMessageContent::try_from(&user_msg.content)?),
name: user_msg.name.clone(),
role: OpenAIChatMessageRole::from(&user_msg.role),
tool_calls: None,
Expand Down
6 changes: 6 additions & 0 deletions core/src/providers/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use std::fmt;
use std::str::FromStr;
use std::time::Duration;

use super::togetherai::TogetherAIProvider;

#[derive(Debug, Clone, Copy, Serialize, PartialEq, ValueEnum, Deserialize)]
#[serde(rename_all = "lowercase")]
#[clap(rename_all = "lowercase")]
Expand All @@ -26,6 +28,7 @@ pub enum ProviderID {
Mistral,
#[serde(rename = "google_ai_studio")]
GoogleAiStudio,
TogetherAI,
}

impl fmt::Display for ProviderID {
Expand All @@ -36,6 +39,7 @@ impl fmt::Display for ProviderID {
ProviderID::Anthropic => write!(f, "anthropic"),
ProviderID::Mistral => write!(f, "mistral"),
ProviderID::GoogleAiStudio => write!(f, "google_ai_studio"),
ProviderID::TogetherAI => write!(f, "togetherai"),
}
}
}
Expand All @@ -49,6 +53,7 @@ impl FromStr for ProviderID {
"anthropic" => Ok(ProviderID::Anthropic),
"mistral" => Ok(ProviderID::Mistral),
"google_ai_studio" => Ok(ProviderID::GoogleAiStudio),
"togetherai" => Ok(ProviderID::TogetherAI),
_ => Err(ParseError::with_message(
"Unknown provider ID \
(possible values: openai, azure_openai, anthropic, mistral, google_ai_studio)",
Expand Down Expand Up @@ -151,5 +156,6 @@ pub fn provider(t: ProviderID) -> Box<dyn Provider + Sync + Send> {
ProviderID::GoogleAiStudio => Box::new(GoogleAiStudioProvider::new()),
ProviderID::Mistral => Box::new(MistralProvider::new()),
ProviderID::OpenAI => Box::new(OpenAIProvider::new()),
ProviderID::TogetherAI => Box::new(TogetherAIProvider::new()),
}
}
Loading

0 comments on commit c678d5f

Please sign in to comment.