diff --git a/connectors/src/connectors/slack/bot.ts b/connectors/src/connectors/slack/bot.ts index 1cfd1adc3e9e..63d793c148cd 100644 --- a/connectors/src/connectors/slack/bot.ts +++ b/connectors/src/connectors/slack/bot.ts @@ -272,7 +272,8 @@ async function botAnswerMessage( } } // Extract all ~mentions. - const mentionCandidates = message.match(/~[a-zA-Z0-9_-]{1,20}/g) || []; + const mentionCandidates = + message.match(/(? 1) { diff --git a/core/src/providers/google_vertex_ai.rs b/core/src/providers/google_vertex_ai.rs index ea53e6aa38df..4c652c75cddc 100644 --- a/core/src/providers/google_vertex_ai.rs +++ b/core/src/providers/google_vertex_ai.rs @@ -20,24 +20,74 @@ use crate::{ use super::{ embedder::Embedder, - llm::{ChatFunction, ChatMessage, LLMChatGeneration, LLMGeneration, LLM}, + llm::{ChatFunction, ChatFunctionCall, ChatMessage, LLMChatGeneration, LLMGeneration, LLM}, provider::{Provider, ProviderID}, tiktoken::tiktoken::{ cl100k_base_singleton, decode_async, encode_async, tokenize_async, CoreBPE, }, }; +// Disabled for now as it requires using a "tools" API which we don't support yet. +pub const USE_FUNCTION_CALLING: bool = false; + #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct UsageMetadata { prompt_token_count: usize, - candidates_token_count: usize, + candidates_token_count: Option, total_token_count: usize, } #[derive(Serialize, Deserialize, Debug, Clone)] +pub struct VertexAiFunctionResponseContent { + name: String, + content: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct VertexAiFunctionResponse { + name: String, + response: VertexAiFunctionResponseContent, +} + +impl TryFrom<&ChatMessage> for VertexAiFunctionResponse { + type Error = anyhow::Error; + + fn try_from(m: &ChatMessage) -> Result { + let name = m.name.clone().unwrap_or_default(); + Ok(VertexAiFunctionResponse { + name: name.clone(), + response: VertexAiFunctionResponseContent { + name: name, + content: m.content.clone().unwrap_or_default(), + }, + }) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct VertexAiFunctionCall { + name: String, + args: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] pub struct Part { - text: String, + text: Option, + function_call: Option, + function_response: Option, +} + +impl TryFrom<&ChatFunctionCall> for VertexAiFunctionCall { + type Error = anyhow::Error; + + fn try_from(f: &ChatFunctionCall) -> Result { + Ok(VertexAiFunctionCall { + name: f.name.clone(), + args: f.arguments.clone(), + }) + } } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -52,23 +102,57 @@ impl TryFrom<&ChatMessage> for Content { fn try_from(m: &ChatMessage) -> Result { Ok(Content { role: match m.role { - ChatMessageRole::Assistant | ChatMessageRole::Function => String::from("MODEL"), + ChatMessageRole::Assistant => String::from("MODEL"), + ChatMessageRole::Function => match m.function_call { + // Role "function" is reserved for function responses. + None if USE_FUNCTION_CALLING => String::from("FUNCTION"), + None => String::from("USER"), + // Function calls are done as role "model". + Some(_) => String::from("MODEL"), + }, _ => String::from("USER"), }, parts: vec![Part { text: match m.role { - ChatMessageRole::System => format!( - "SYSTEM: {}\n", + ChatMessageRole::System => Some(format!( + "[user: SYSTEM] {}\n", m.content.clone().unwrap_or(String::from("")) - ), - _ => match m.name { - Some(ref name) => format!( - "[name: {}]: {}", + )), + ChatMessageRole::User => match m.name { + Some(ref name) => Some(format!( + "[user: {}] {}", name, m.content.clone().unwrap_or(String::from("")) - ), - None => m.content.clone().unwrap_or(String::from("")), + )), + None => Some(m.content.clone().unwrap_or(String::from(""))), + }, + ChatMessageRole::Function if USE_FUNCTION_CALLING => None, + ChatMessageRole::Function => match m.name { + Some(ref name) => Some(format!( + "[function_result: {}] {}", + name, + m.content.clone().unwrap_or(String::from("")) + )), + None => Some(format!( + "[function_result] {}", + m.content.clone().unwrap_or(String::from("")) + )), }, + ChatMessageRole::Assistant => { + Some(m.content.clone().unwrap_or(String::from(""))) + } + }, + function_call: match m.function_call.clone() { + Some(function_call) if USE_FUNCTION_CALLING => { + VertexAiFunctionCall::try_from(&function_call).ok() + } + _ => None, + }, + function_response: match m.role { + ChatMessageRole::Function if USE_FUNCTION_CALLING => { + VertexAiFunctionResponse::try_from(m).ok() + } + _ => None, }, }], }) @@ -85,7 +169,7 @@ pub struct Candidate { #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct Completion { - candidates: Vec, + candidates: Option>, usage_metadata: Option, } @@ -265,7 +349,9 @@ impl LLM for GoogleVertexAiLLM { &vec![Content { role: String::from("USER"), parts: vec![Part { - text: String::from(prompt), + text: Some(String::from(prompt)), + function_call: None, + function_response: None, }], }], &vec![], @@ -287,7 +373,16 @@ impl LLM for GoogleVertexAiLLM { provider: ProviderID::GoogleVertexAi.to_string(), model: self.id().clone(), completions: vec![Tokens { - text: c.candidates[0].content.parts[0].text.clone(), + text: match c.candidates { + None => String::from(""), + Some(candidates) => match candidates.len() { + 0 => String::from(""), + _ => candidates[0].content.parts[0] + .text + .clone() + .unwrap_or_default(), + }, + }, tokens: Some(vec![]), logprobs: Some(vec![]), top_logprobs: None, @@ -327,10 +422,12 @@ impl LLM for GoogleVertexAiLLM { } if functions.len() > 0 || function_call.is_some() { - Err(anyhow!( - "Functions on Google Vertex AI are not implemented yet." - ))?; + if USE_FUNCTION_CALLING { + unimplemented!("Functions on Google Vertex AI are not implemented yet."); + } + Err(anyhow!("Functions on Google Vertex AI are disabled."))?; } + if frequency_penalty.is_some() { Err(anyhow!( "Frequency penalty not supported by Google Vertex AI" @@ -375,7 +472,18 @@ impl LLM for GoogleVertexAiLLM { name: None, function_call: None, role: ChatMessageRole::Assistant, - content: Some(c.candidates[0].content.parts[0].text.clone()), + content: match c.candidates { + None => None, + Some(candidates) => match candidates.len() { + 0 => None, + _ => Some( + candidates[0].content.parts[0] + .text + .clone() + .unwrap_or_default(), + ), + }, + }, }], }) } @@ -397,17 +505,33 @@ pub async fn streamed_chat_completion( let https = HttpsConnector::new(); let url = uri.to_string(); - // Squash messages for the same role. - // Gemini doesn't allow multiple messages from the same role in a row. + // Ensure that all input message have one single part. + messages + .iter() + .map(|m| match m.parts.len() { + 0 => Err(anyhow!("Message has no parts")), + 1 => Ok(()), + _ => Err(anyhow!("Message has more than one part")), + }) + .collect::>>()?; + + // Squash user messages. + // Gemini doesn't allow multiple user or assistant messages in a row. let messages: Vec = messages .iter() .fold( + // First we merge consecutive user/assistant messages by making them a multi-part message. Vec::::new(), |mut acc: Vec, m: &Content| { match acc.last_mut() { - Some(last) if last.role == m.role => { + Some(last) + if last.role == m.role + && ["MODEL", "USER"].contains(&m.role.to_uppercase().as_str()) => + { last.parts.push(Part { text: m.parts[0].text.clone(), + function_call: None, + function_response: None, }); } _ => { @@ -418,16 +542,23 @@ pub async fn streamed_chat_completion( }, ) .iter() - .map(|m| Content { - role: m.role.clone(), - parts: vec![Part { - text: m - .parts - .iter() - .map(|p| p.text.clone()) - .collect::>() - .join(" "), - }], + // Then we squash the parts together. + .map(|m| match m.role.to_uppercase().as_str() { + "USER" | "MODEL" => Content { + role: m.role.clone(), + parts: vec![Part { + text: Some( + m.parts + .iter() + .map(|p| p.text.clone().unwrap_or_default()) + .collect::>() + .join("\n"), + ), + function_call: None, + function_response: None, + }], + }, + _ => m.clone(), }) .collect::>(); @@ -435,7 +566,7 @@ pub async fn streamed_chat_completion( Ok(builder) => builder, Err(e) => { return Err(anyhow!( - "Error creating Anthropic streaming client: {:?}", + "Error creating Google Vertex AI streaming client: {:?}", e )) } @@ -492,33 +623,66 @@ pub async fn streamed_chat_completion( } Some(es::SSE::Event(e)) => { let completion: Completion = serde_json::from_str(e.data.as_str())?; - if completion.candidates.len() != 1 { - Err(anyhow!( - "Unexpected number of candidates: {}", - completion.candidates.len() - ))?; - } - if completion.candidates[0].content.parts.len() != 1 { - Err(anyhow!( - "Unexpected number of parts: {}", - completion.candidates[0].content.parts.len() - ))?; - } + let completion_candidates = completion.candidates.clone().unwrap_or_default(); + + match completion_candidates.len() { + 0 => { + break 'stream; + } + 1 => (), + n => { + Err(anyhow!("Unexpected number of candidates: {}", n))?; + } + }; + + match completion_candidates[0].content.parts.len() { + 1 => (), + n => { + Err(anyhow!("Unexpected number of parts: {}", n))?; + } + }; match event_sender.as_ref() { Some(sender) => { - let text = completion.candidates[0].content.parts[0].text.clone(); - if text.len() > 0 { - let _ = sender.send(json!({ - "type": "tokens", - "content": { - "text": text, + // let text = completion.candidates[0].content.parts[0].text.clone(); + match completion_candidates[0].content.parts[0].text.clone() { + Some(t) => { + if t.len() > 0 { + let _ = sender.send(json!({ + "type": "tokens", + "content": { + "text": t, + } + })); } - })); + } + None => (), + } + + match completion_candidates[0].content.parts[0] + .function_call + .clone() + { + Some(f) => { + let _ = sender.send(json!({ + "type": "function_call", + "content": { + "name": f.name, + } + })); + let _ = sender.send(json!({ + "type": "function_call_arguments_tokens", + "content": { + "text": f.args, + } + })); + } + None => (), } } + _ => (), - } + }; completions.lock().push(completion); } @@ -536,41 +700,135 @@ pub async fn streamed_chat_completion( } } - let mut completion = Completion { - candidates: vec![Candidate { + let completions_lock = completions.lock(); + + // Sometimes (usually when last message is Assistant), the AI decides not to respond. + if completions_lock.len() == 0 { + return Ok(Completion { + candidates: None, + usage_metadata: None, + }); + } + + // Ensure that we don't have a mix of `function_call` and `text` in the same completion. + // Ensure that all the roles are "MODEL" + // We merge all the completions texts together. + let mut full_completion_text = String::from(""); + let mut function_call_name = String::from(""); + let mut function_call_args = String::from(""); + let mut finish_reason = String::from(""); + let mut usage_metadata = UsageMetadata { + prompt_token_count: 0, + candidates_token_count: None, + total_token_count: 0, + }; + for c in completions_lock.iter() { + match &c.usage_metadata { + None => (), + Some(um) => { + usage_metadata.prompt_token_count = um.prompt_token_count; + usage_metadata.candidates_token_count = um.candidates_token_count; + usage_metadata.total_token_count = um.total_token_count; + } + } + match &c.candidates { + None => (), + Some(candidates) => match candidates.len() { + 0 => (), + 1 => { + match candidates[0].content.role.to_uppercase().as_str() { + "MODEL" => (), + _ => Err(anyhow!(format!( + "Unexpected role in completion: {}", + candidates[0].content.role + )))?, + }; + match &candidates[0].finish_reason { + None => (), + Some(r) => { + match finish_reason.len() { + 0 => finish_reason = r.clone(), + _ => Err(anyhow!("Unexpected finish reason"))?, + }; + } + } + match candidates[0].content.parts.len() { + 0 => (), + 1 => { + match candidates[0].content.parts[0].text.clone() { + Some(t) => { + if function_call_name.len() > 0 || function_call_args.len() > 0 + { + Err(anyhow!("Unexpected text in function call"))?; + } + full_completion_text.push_str(t.as_str()); + } + None => (), + }; + match candidates[0].content.parts[0].function_call.clone() { + Some(f) => { + if full_completion_text.len() > 0 { + Err(anyhow!("Unexpected function call in text"))?; + } + match f.name.len() { + 0 => (), + _ if function_call_name.is_empty() => { + function_call_name = f.name.clone(); + } + _ => { + if function_call_name != f.name { + Err(anyhow!("Function call name mismatch"))?; + } + } + } + match f.args.len() { + 0 => (), + _ if function_call_args.is_empty() => { + function_call_args.push_str(f.args.as_str()); + } + _ => (), + } + } + None => (), + } + } + _ => (), + } + } + _ => Err(anyhow!("Unexpected number of candidates"))?, + }, + } + } + + if finish_reason.len() == 0 { + Err(anyhow!("No finish reason"))?; + } + + if function_call_name.len() == 0 && full_completion_text.len() == 0 { + Err(anyhow!("No text and no function call"))?; + } + + Ok(Completion { + candidates: Some(vec![Candidate { content: Content { role: String::from("MODEL"), parts: vec![Part { - text: String::from(""), + text: match full_completion_text.len() { + 0 => None, + _ => Some(full_completion_text), + }, + function_call: match function_call_name.len() { + 0 => None, + _ => Some(VertexAiFunctionCall { + name: function_call_name, + args: function_call_args, + }), + }, + function_response: None, }], }, - finish_reason: None, - }], - usage_metadata: Some(UsageMetadata { - prompt_token_count: 0, - candidates_token_count: 0, - total_token_count: 0, - }), - }; - - completions.lock().iter().for_each(|c| { - completion.candidates[0].content.parts[0].text.push_str( - c.candidates[0] - .content - .parts - .iter() - .map(|p| p.text.as_str()) - .collect::>() - .join(" ") - .as_str(), - ); - if c.candidates[0].finish_reason.is_some() { - completion.candidates[0].finish_reason = c.candidates[0].finish_reason.clone(); - } - if c.usage_metadata.is_some() { - completion.usage_metadata = c.usage_metadata.clone(); - } - }); - - Ok(completion) + finish_reason: Some(finish_reason), + }]), + usage_metadata: Some(usage_metadata), + }) } diff --git a/front/components/assistant/AssistantActions.tsx b/front/components/assistant/AssistantActions.tsx index b1b4769d7dfc..308400cedfb8 100644 --- a/front/components/assistant/AssistantActions.tsx +++ b/front/components/assistant/AssistantActions.tsx @@ -126,7 +126,7 @@ export function RemoveAssistantFromListDialog({ }); } else { sendNotification({ - title: `Assistant removed`, + title: `Assistant removed from your list`, type: "success", }); onRemove(); diff --git a/front/components/assistant/AssistantDetails.tsx b/front/components/assistant/AssistantDetails.tsx index 99d0d7dfeb2e..7d9f232e045c 100644 --- a/front/components/assistant/AssistantDetails.tsx +++ b/front/components/assistant/AssistantDetails.tsx @@ -4,11 +4,11 @@ import { ClipboardIcon, CloudArrowDownIcon, CommandLineIcon, - DashIcon, Modal, PlusIcon, ServerIcon, TrashIcon, + XMarkIcon, } from "@dust-tt/sparkle"; import { AgentUserListStatus, @@ -30,13 +30,12 @@ import Link from "next/link"; import { useContext, useState } from "react"; import ReactMarkdown from "react-markdown"; +import { DeleteAssistantDialog } from "@app/components/assistant/AssistantActions"; import { SendNotificationsContext } from "@app/components/sparkle/Notification"; import { CONNECTOR_CONFIGURATIONS } from "@app/lib/connector_providers"; import { useApp, useDatabase } from "@app/lib/swr"; import { PostAgentListStatusRequestBody } from "@app/pages/api/w/[wId]/members/me/agent_list_status"; -import { DeleteAssistantDialog } from "./AssistantActions"; - type AssistantDetailsFlow = "personal" | "workspace"; export function AssistantDetails({ @@ -296,7 +295,11 @@ function ButtonsSection({ }); } else { sendNotification({ - title: `Assistant ${listStatus === "in-list" ? "added" : "removed"}`, + title: `Assistant ${ + listStatus === "in-list" + ? "added to your list" + : "removed from your list" + }`, type: "success", }); onUpdate(); @@ -305,30 +308,31 @@ function ButtonsSection({ setIsAddingOrRemoving(false); onClose(); }; - return ( - -