Skip to content

Commit

Permalink
Merge branch 'main' into SearchPolish
Browse files Browse the repository at this point in the history
  • Loading branch information
Duncid authored Oct 3, 2023
2 parents 787c5ee + 9211f98 commit 085409e
Show file tree
Hide file tree
Showing 29 changed files with 892 additions and 147 deletions.
70 changes: 57 additions & 13 deletions connectors/src/connectors/slack/bot.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import {
AgentActionType,
AgentGenerationSuccessEvent,
AgentMessageType,
DustAPI,
RetrievalDocumentType,
} from "@connectors/lib/dust_api";
import {
Connector,
Expand Down Expand Up @@ -161,7 +163,7 @@ async function botAnswerMessage(
});
const mainMessage = await slackClient.chat.postMessage({
channel: slackChannel,
text: "_I am thinking..._",
text: "_Thinking..._",
thread_ts: slackMessageTs,
mrkdwn: true,
});
Expand Down Expand Up @@ -238,6 +240,7 @@ async function botAnswerMessage(
}

let fullAnswer = "";
let action: AgentActionType | null = null;
let lastSentDate = new Date();
for await (const event of streamRes.value.eventStream) {
switch (event.type) {
Expand All @@ -258,24 +261,31 @@ async function botAnswerMessage(
}
case "generation_tokens": {
fullAnswer += event.text;
if (lastSentDate.getTime() + 1000 > new Date().getTime()) {
if (lastSentDate.getTime() + 1500 > new Date().getTime()) {
continue;
}
lastSentDate = new Date();

let finalAnswer = _processCiteMention(fullAnswer, action);
finalAnswer += `...\n\n <${DUST_API}/w/${connector.workspaceId}/assistant/${conversation.sId}|Continue this conversation on Dust>`;

await slackClient.chat.update({
channel: slackChannel,
text: fullAnswer,
text: finalAnswer,
ts: mainMessage.ts as string,
thread_ts: slackMessageTs,
});
break;
}
case "agent_action_success": {
action = event.action;
break;
}
case "agent_generation_success": {
const finalAnswer = `${_removeCiteMention(
event.text
)}\n\n <${DUST_API}/w/${connector.workspaceId}/assistant/${
conversation.sId
}|Continue this conversation on Dust>`;
fullAnswer = event.text;

let finalAnswer = _processCiteMention(fullAnswer, action);
finalAnswer += `\n\n <${DUST_API}/w/${connector.workspaceId}/assistant/${conversation.sId}|Continue this conversation on Dust>`;

await slackClient.chat.update({
channel: slackChannel,
Expand All @@ -293,11 +303,45 @@ async function botAnswerMessage(
return new Err(new Error("Failed to get the final answer from Dust"));
}

/*
* Temp > until I have a PR to properly handle mentions
*/
function _removeCiteMention(message: string) {
const regex = /:cite\[[a-zA-Z0-9,]+\]/g;
/**
 * Replaces `:cite[...]` mentions in `content` with numbered Slack mrkdwn links
 * to the retrieved documents referenced by `action`.
 *
 * Each distinct reference key is assigned an incrementing counter on first use
 * and rendered as `[<sourceUrl|n>]`. Keys with no matching document, or whose
 * document has no `sourceUrl`, are dropped from the output. When the action
 * yields no retrieval documents at all, every cite mention is simply stripped
 * via `_removeCiteMention`.
 */
function _processCiteMention(
  content: string,
  action: AgentActionType | null
): string {
  const references: { [key: string]: RetrievalDocumentType } = {};

  if (action && action.type === "retrieval_action" && action.documents) {
    action.documents.forEach((d) => {
      references[d.reference] = d;
    });
  }

  // An object literal is always truthy, so guard on whether any references
  // were actually collected — otherwise the fallback below is dead code.
  if (Object.keys(references).length > 0) {
    let counter = 0;
    const refCounter: { [key: string]: number } = {};
    return content.replace(/:cite\[[a-zA-Z0-9, ]+\]/g, (match) => {
      // Slice off ":cite[" and the trailing "]", then split keys by comma.
      const keys = match.slice(6, -1).split(",");
      return keys
        .map((key) => {
          const k = key.trim();
          const ref = references[k];
          if (ref && ref.sourceUrl) {
            if (!refCounter[k]) {
              counter++;
              refCounter[k] = counter;
            }
            return `[<${ref.sourceUrl}|${refCounter[k]}>]`;
          }
          return "";
        })
        .join("");
    });
  }

  return _removeCiteMention(content);
}

/**
 * Strips every `:cite[...]` mention (comma-separated alphanumeric keys) from
 * the message, leaving the surrounding text untouched.
 */
function _removeCiteMention(message: string): string {
  return message.replace(/:cite\[[a-zA-Z0-9, ]+\]/g, "");
}

Expand Down
1 change: 0 additions & 1 deletion connectors/src/lib/dust_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ export type AgentMessageType = {
visibility: MessageVisibility;
version: number;
parentMessageId: string | null;

// configuration: AgentConfigurationType;
status: AgentMessageStatus;
action: AgentActionType | null;
Expand Down
6 changes: 3 additions & 3 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ async fn run_helper(
None => Err(error_response(
StatusCode::BAD_REQUEST,
"missing_specification_error",
"No specification provided, either `specification` or
"No specification provided, either `specification` or
`specification_hash` must be provided",
None,
))?,
Expand Down Expand Up @@ -1610,8 +1610,8 @@ struct TokenizePayload {
async fn tokenize(
extract::Json(payload): extract::Json<TokenizePayload>,
) -> (StatusCode, Json<APIResponse>) {
let embedder = provider(payload.provider_id).embedder(payload.model_id);
match embedder.tokenize(payload.text).await {
let embedder = provider(payload.provider_id).llm(payload.model_id);
match embedder.tokenize(&payload.text).await {
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
Expand Down
1 change: 1 addition & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod providers {
pub mod tiktoken;
}
pub mod anthropic;
pub mod textsynth;
}
pub mod http {
pub mod request;
Expand Down
8 changes: 4 additions & 4 deletions core/src/providers/ai21.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ impl LLM for AI21LLM {
Err(anyhow!("Encode/Decode not implemented for provider `ai21`"))
}

/// Tokenization is not implemented for the `ai21` provider; always returns an error.
async fn tokenize(&self, _text: &str) -> Result<Vec<(usize, String)>> {
Err(anyhow!("Tokenize not implemented for provider `ai21`"))
}

async fn generate(
&self,
prompt: &str,
Expand Down Expand Up @@ -385,10 +389,6 @@ impl Embedder for AI21Embedder {
Err(anyhow!("Encode/Decode not implemented for provider `ai21`"))
}

async fn tokenize(&self, _text: String) -> Result<Vec<(usize, String)>> {
Err(anyhow!("Tokenize not implemented for provider `ai21`"))
}

async fn embed(&self, _text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
Err(anyhow!("Embeddings not available for provider `ai21`"))
}
Expand Down
8 changes: 4 additions & 4 deletions core/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,10 @@ impl LLM for AnthropicLLM {
decode_async(anthropic_base_singleton(), tokens).await
}

/// Tokenizes `text` into (token id, token string) pairs using the shared
/// Anthropic base tokenizer singleton, delegating to `tokenize_async`.
async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(anthropic_base_singleton(), text).await
}

async fn chat(
&self,
messages: &Vec<ChatMessage>,
Expand Down Expand Up @@ -665,10 +669,6 @@ impl Embedder for AnthropicEmbedder {
decode_async(anthropic_base_singleton(), tokens).await
}

async fn tokenize(&self, text: String) -> Result<Vec<(usize, String)>> {
tokenize_async(anthropic_base_singleton(), text).await
}

async fn embed(&self, _text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
Err(anyhow!("Embeddings not available for provider `anthropic`"))
}
Expand Down
15 changes: 7 additions & 8 deletions core/src/providers/azure_openai.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
use crate::providers::embedder::{Embedder, EmbedderVector};
use crate::providers::llm::Tokens;
use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLM};
Expand Down Expand Up @@ -265,13 +266,15 @@ impl LLM for AzureOpenAILLM {
}

async fn encode(&self, text: &str) -> Result<Vec<usize>> {
let tokens = { self.tokenizer().lock().encode_with_special_tokens(text) };
Ok(tokens)
encode_async(self.tokenizer(), text).await
}

async fn decode(&self, tokens: Vec<usize>) -> Result<String> {
let str = { self.tokenizer().lock().decode(tokens)? };
Ok(str)
decode_async(self.tokenizer(), tokens).await
}

/// Tokenizes `text` into (token id, token string) pairs with this model's
/// tokenizer, delegating to `tokenize_async` (consistent with `encode`/`decode` above).
async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
}

async fn generate(
Expand Down Expand Up @@ -600,10 +603,6 @@ impl Embedder for AzureOpenAIEmbedder {
Ok(str)
}

async fn tokenize(&self, _text: String) -> Result<Vec<(usize, String)>> {
Err(anyhow!("Tokenize not implemented for provider `anthropic`"))
}

async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
let e = embed(
self.uri()?,
Expand Down
14 changes: 10 additions & 4 deletions core/src/providers/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,16 @@ impl LLM for CohereLLM {
api_decode(self.api_key.as_ref().unwrap(), tokens).await
}

// Cohere's encode API only yields token ids, so each id is paired with an
// empty string to partially support the tokenize endpoint.
async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
    assert!(self.api_key.is_some());
    let ids = api_encode(self.api_key.as_ref().unwrap(), text).await?;
    Ok(ids.into_iter().map(|id| (id, String::new())).collect())
}

async fn generate(
&self,
prompt: &str,
Expand Down Expand Up @@ -534,10 +544,6 @@ impl Embedder for CohereEmbedder {
api_decode(self.api_key.as_ref().unwrap(), tokens).await
}

async fn tokenize(&self, _text: String) -> Result<Vec<(usize, String)>> {
Err(anyhow!("Tokenize not implemented for provider `Cohere`"))
}

async fn embed(&self, text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
assert!(self.api_key.is_some());

Expand Down
2 changes: 0 additions & 2 deletions core/src/providers/embedder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ pub trait Embedder {
async fn encode(&self, text: &str) -> Result<Vec<usize>>;
async fn decode(&self, tokens: Vec<usize>) -> Result<String>;

async fn tokenize(&self, text: String) -> Result<Vec<(usize, String)>>;

async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>>;
}

Expand Down
1 change: 1 addition & 0 deletions core/src/providers/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ pub trait LLM {

async fn encode(&self, text: &str) -> Result<Vec<usize>>;
async fn decode(&self, tokens: Vec<usize>) -> Result<String>;
async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>>;

async fn generate(
&self,
Expand Down
8 changes: 4 additions & 4 deletions core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,10 @@ impl LLM for OpenAILLM {
decode_async(self.tokenizer(), tokens).await
}

/// Tokenizes `text` into (token id, token string) pairs with this model's
/// tokenizer, delegating to `tokenize_async`.
async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
}

async fn generate(
&self,
prompt: &str,
Expand Down Expand Up @@ -1575,10 +1579,6 @@ impl Embedder for OpenAIEmbedder {
decode_async(self.tokenizer(), tokens).await
}

async fn tokenize(&self, text: String) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
}

async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
let e = embed(
self.uri()?,
Expand Down
6 changes: 6 additions & 0 deletions core/src/providers/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::time::Duration;

use super::textsynth::TextSynthProvider;

#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderID {
Expand All @@ -22,6 +24,7 @@ pub enum ProviderID {
#[serde(rename = "azure_openai")]
AzureOpenAI,
Anthropic,
TextSynth,
}

impl ToString for ProviderID {
Expand All @@ -32,6 +35,7 @@ impl ToString for ProviderID {
ProviderID::AI21 => String::from("ai21"),
ProviderID::AzureOpenAI => String::from("azure_openai"),
ProviderID::Anthropic => String::from("anthropic"),
ProviderID::TextSynth => String::from("textsynth"),
}
}
}
Expand All @@ -45,6 +49,7 @@ impl FromStr for ProviderID {
"ai21" => Ok(ProviderID::AI21),
"azure_openai" => Ok(ProviderID::AzureOpenAI),
"anthropic" => Ok(ProviderID::Anthropic),
"textsynth" => Ok(ProviderID::TextSynth),
_ => Err(ParseError::with_message(
"Unknown provider ID (possible values: openai, cohere, ai21, azure_openai)",
))?,
Expand Down Expand Up @@ -139,5 +144,6 @@ pub fn provider(t: ProviderID) -> Box<dyn Provider + Sync + Send> {
ProviderID::AI21 => Box::new(AI21Provider::new()),
ProviderID::AzureOpenAI => Box::new(AzureOpenAIProvider::new()),
ProviderID::Anthropic => Box::new(AnthropicProvider::new()),
ProviderID::TextSynth => Box::new(TextSynthProvider::new()),
}
}
Loading

0 comments on commit 085409e

Please sign in to comment.