diff --git a/core/bin/dust_api.rs b/core/bin/dust_api.rs index 3234eb57eb80..5ebff27ea641 100644 --- a/core/bin/dust_api.rs +++ b/core/bin/dust_api.rs @@ -1611,7 +1611,7 @@ async fn tokenize( extract::Json(payload): extract::Json, ) -> (StatusCode, Json) { let embedder = provider(payload.provider_id).embedder(payload.model_id); - match embedder.encode(&payload.text).await { + match embedder.tokenize(payload.text).await { Err(e) => error_response( StatusCode::INTERNAL_SERVER_ERROR, "internal_server_error", diff --git a/core/src/providers/ai21.rs b/core/src/providers/ai21.rs index 0b7dfe68b6fc..0e157394705f 100644 --- a/core/src/providers/ai21.rs +++ b/core/src/providers/ai21.rs @@ -385,6 +385,10 @@ impl Embedder for AI21Embedder { Err(anyhow!("Encode/Decode not implemented for provider `ai21`")) } + async fn tokenize(&self, _text: String) -> Result> { + Err(anyhow!("Tokenize not implemented for provider `ai21`")) + } + async fn embed(&self, _text: Vec<&str>, _extras: Option) -> Result> { Err(anyhow!("Embeddings not available for provider `ai21`")) } diff --git a/core/src/providers/anthropic.rs b/core/src/providers/anthropic.rs index 48498096a8f2..12739d688b9d 100644 --- a/core/src/providers/anthropic.rs +++ b/core/src/providers/anthropic.rs @@ -631,6 +631,10 @@ impl Embedder for AnthropicEmbedder { )) } + async fn tokenize(&self, _text: String) -> Result> { + Err(anyhow!("Tokenize not implemented for provider `anthropic`")) + } + async fn embed(&self, _text: Vec<&str>, _extras: Option) -> Result> { Err(anyhow!("Embeddings not available for provider `anthropic`")) } diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs index d7df2904976d..396b30ff80d2 100644 --- a/core/src/providers/azure_openai.rs +++ b/core/src/providers/azure_openai.rs @@ -600,6 +600,10 @@ impl Embedder for AzureOpenAIEmbedder { Ok(str) } + async fn tokenize(&self, _text: String) -> Result> { + Err(anyhow!("Tokenize not implemented for provider `azure_openai`")) + } + 
async fn embed(&self, text: Vec<&str>, extras: Option) -> Result> { let e = embed( self.uri()?, diff --git a/core/src/providers/cohere.rs b/core/src/providers/cohere.rs index 2e00e894a877..19a73ca0f992 100644 --- a/core/src/providers/cohere.rs +++ b/core/src/providers/cohere.rs @@ -534,6 +534,10 @@ impl Embedder for CohereEmbedder { api_decode(self.api_key.as_ref().unwrap(), tokens).await } + async fn tokenize(&self, _text: String) -> Result> { + Err(anyhow!("Tokenize not implemented for provider `cohere`")) + } + async fn embed(&self, text: Vec<&str>, _extras: Option) -> Result> { assert!(self.api_key.is_some()); diff --git a/core/src/providers/embedder.rs b/core/src/providers/embedder.rs index 1257fb48f894..86b0f69a4fbc 100644 --- a/core/src/providers/embedder.rs +++ b/core/src/providers/embedder.rs @@ -26,6 +26,8 @@ pub trait Embedder { async fn encode(&self, text: &str) -> Result>; async fn decode(&self, tokens: Vec) -> Result; + async fn tokenize(&self, text: String) -> Result>; + async fn embed(&self, text: Vec<&str>, extras: Option) -> Result>; } diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs index da7df8310794..19542723e52f 100644 --- a/core/src/providers/openai.rs +++ b/core/src/providers/openai.rs @@ -28,7 +28,7 @@ use tokio::sync::mpsc::UnboundedSender; use tokio::time::timeout; use super::llm::{ChatFunction, ChatFunctionCall}; -use super::tiktoken::tiktoken::{decode_async, encode_async}; +use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Usage { @@ -1575,6 +1575,10 @@ impl Embedder for OpenAIEmbedder { decode_async(self.tokenizer(), tokens).await } + async fn tokenize(&self, text: String) -> Result> { + tokenize_async(self.tokenizer(), text).await + } + async fn embed(&self, text: Vec<&str>, extras: Option) -> Result> { let e = embed( self.uri()?, diff --git a/core/src/providers/tiktoken/tiktoken.rs b/core/src/providers/tiktoken/tiktoken.rs 
index d6b00a14e17a..0c70c11ee9ad 100644 --- a/core/src/providers/tiktoken/tiktoken.rs +++ b/core/src/providers/tiktoken/tiktoken.rs @@ -112,6 +112,14 @@ pub async fn encode_async(bpe: Arc>, text: &str) -> Result>, + text: String, +) -> Result> { + let r = task::spawn_blocking(move || bpe.lock().tokenize(&text)).await?; + Ok(r) +} + fn _byte_pair_merge(piece: &[u8], ranks: &HashMap, usize>) -> Vec> { let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect(); @@ -241,6 +249,49 @@ impl CoreBPE { ret } + fn _tokenize(&self, text: &String) -> Vec<(usize, String)> { + let regex = self._get_regex(); + let mut results = vec![]; + + for mat in regex.find_iter(text) { + let string = mat.unwrap().as_str(); + let piece = string.as_bytes(); + if let Some(token) = self.encoder.get(piece) { + results.push((*token, string.to_string())); + continue; + } + + results.extend(Self::_tokenize_byte_pair_encode(piece, &self.encoder)); + } + results + } + + /** + * Implemented to match the logic in _encode_ordinary_native + * Used in tokenize function + */ + pub fn _tokenize_byte_pair_encode( + piece: &[u8], + ranks: &HashMap, usize>, + ) -> Vec<(usize, String)> { + if piece.len() == 1 { + let string = std::str::from_utf8(&piece).unwrap(); + return vec![(ranks[piece], string.to_string())]; + } + + _byte_pair_merge(piece, ranks) + .iter() + .map(|p| { + ( + ranks[&piece[p.start..p.end]], + std::str::from_utf8(&piece[p.start..p.end]) + .unwrap() + .to_string(), + ) + }) + .collect() + } + fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec, usize) { let special_regex = self._get_special_regex(); let regex = self._get_regex(); @@ -506,6 +557,10 @@ impl CoreBPE { self._encode_native(text, &allowed_special).0 } + pub fn tokenize(&self, text: &String) -> Vec<(usize, String)> { + self._tokenize(text) + } + pub fn encode_with_special_tokens(&self, text: &str) -> Vec { let allowed_special = self .special_tokens_encoder diff --git a/front/lib/core_api.ts 
b/front/lib/core_api.ts index 20f060ffc0a1..561e70def590 100644 --- a/front/lib/core_api.ts +++ b/front/lib/core_api.ts @@ -88,6 +88,8 @@ export type CoreAPIRun = { traces: Array<[[BlockType, string], Array>]>; }; +export type CoreAPITokenType = [number, string]; + type CoreAPICreateRunParams = { projectId: string; runAsWorkspaceId: string; @@ -726,7 +728,7 @@ export const CoreAPI = { text: string; modelId: string; providerId: string; - }): Promise> { + }): Promise> { const response = await fetch(`${CORE_API}/tokenize`, { method: "POST", headers: { diff --git a/front/lib/extract_event_app.ts b/front/lib/extract_event_app.ts index b7ae45733c50..2eaa9a79efcf 100644 --- a/front/lib/extract_event_app.ts +++ b/front/lib/extract_event_app.ts @@ -4,13 +4,12 @@ import { } from "@app/lib/actions/registry"; import { runAction } from "@app/lib/actions/server"; import { Authenticator } from "@app/lib/auth"; -import { CoreAPI } from "@app/lib/core_api"; +import { CoreAPI, CoreAPITokenType } from "@app/lib/core_api"; +import { findMarkersIndexes } from "@app/lib/extract_event_markers"; import { formatPropertiesForModel } from "@app/lib/extract_events_properties"; import logger from "@app/logger/logger"; import { EventSchemaType } from "@app/types/extract"; -const EXTRACT_MAX_NUMBER_TOKENS_TO_PROCESS = 6000; - export type ExtractEventAppResponseResults = { value: { results: { value: string }[][]; @@ -65,53 +64,124 @@ export async function _runExtractEventApp({ } /** - * Return the content to process by the Extract Event app. - * If the document is too big, we send only part of it to the Dust App. - * @param fullDocumentText - * @param marker + * Gets the maximum text content to process for the Dust app. + * We define a maximum number of tokens that the Dust app can process. + * It will return the text around the marker: first we expand the text before the marker, than after the marker. 
*/ export async function _getMaxTextContentToProcess({ - fullDocumentText, + fullText, marker, }: { - fullDocumentText: string; + fullText: string; marker: string; }): Promise { - const tokensInDocumentText = await CoreAPI.tokenize({ - text: fullDocumentText, + const tokenized = await getTokenizedText(fullText); + const tokens = tokenized.tokens; + const nbTokens = tokens.length; + const MAX_TOKENS = 6000; + + // If the text is small enough, just return it + if (nbTokens < MAX_TOKENS) { + return fullText; + } + + // Otherwise we extract the tokens around the marker + // and return the text corresponding to those tokens + const extractTokensResult = extractMaxTokens({ + fullText, + tokens, + marker, + maxTokens: MAX_TOKENS, + }); + + return extractTokensResult.map((t) => t[1]).join(""); +} + +/** + * Extracts the maximum number of tokens around the marker. + */ +function extractMaxTokens({ + fullText, + tokens, + marker, + maxTokens, +}: { + fullText: string; + tokens: CoreAPITokenType[]; + marker: string; + maxTokens: number; +}): CoreAPITokenType[] { + const { start, end } = findMarkersIndexes({ fullText, marker, tokens }); + + if (start === -1 || end === -1) { + return []; + } + + // The number of tokens that the marker takes up + const markerTokens = end - start + 1; + + // The number of remaining tokens that can be included around the marker + const remainingTokens = maxTokens - markerTokens; + + // Initialize the slicing start and end points around the marker + let startSlice = start; + let endSlice = end; + + // Try to add tokens before the marker first + if (remainingTokens > 0) { + startSlice = Math.max(0, start - remainingTokens); + + // Calculate any remaining tokens that can be used after the marker + const remainingAfter = remainingTokens - (start - startSlice); + + // If there are any tokens left, add them after the marker + if (remainingAfter > 0) { + endSlice = Math.min(tokens.length - 1, end + remainingAfter); + } + } + + return 
tokens.slice(startSlice, endSlice + 1); +} + +/** + * Calls Core API to get the tokens and associated strings for a given text. + * Ex: "Un petit Soupinou des bois [[idea:2]]" will return: + * { + * tokens: [ + [ 1844, 'Un' ], + [ 46110, ' petit' ], + [ 9424, ' Sou' ], + [ 13576, 'pin' ], + [ 283, 'ou' ], + [ 951, ' des' ], + [ 66304, ' bois' ], + [ 4416, ' [[' ], + [ 42877, 'idea' ], + [ 25, ':' ], + [ 17, '2' ], + [ 5163, ']]' ] + ], + * } + */ + +export async function getTokenizedText( + text: string +): Promise<{ tokens: CoreAPITokenType[] }> { + console.log("computeNbTokens4"); + const tokenizeResponse = await CoreAPI.tokenize({ + text: text, modelId: "text-embedding-ada-002", providerId: "openai", }); - if (tokensInDocumentText.isErr()) { + if (tokenizeResponse.isErr()) { { - tokensInDocumentText.error; + tokenizeResponse.error; } logger.error( "Could not get number of tokens for document, trying with full doc." ); - return fullDocumentText; - } - - const numberOfTokens = tokensInDocumentText.value.tokens.length; - let documentTextToProcess: string; - - if (numberOfTokens > EXTRACT_MAX_NUMBER_TOKENS_TO_PROCESS) { - // Document is too big, we need to send only part of it to the Dust App. - const fullDocLength = fullDocumentText.length; - const markerIndex = fullDocumentText.indexOf(marker); - const markerLength = marker.length; - - // We can go half the max number of tokens on each side of the marker. - // We multiply by 4 because we assume 1 token is approximately 4 characters - const maxLength = (EXTRACT_MAX_NUMBER_TOKENS_TO_PROCESS / 2) * 4; - - const start = Math.max(0, markerIndex - maxLength); - const end = Math.min(fullDocLength, markerIndex + markerLength + maxLength); - documentTextToProcess = fullDocumentText.substring(start, end); - } else { - // Document is small enough, we send the whole text. 
- documentTextToProcess = fullDocumentText; + return { tokens: [] }; } - return documentTextToProcess; + return tokenizeResponse.value; } diff --git a/front/lib/extract_event_markers.ts b/front/lib/extract_event_markers.ts index a626622e3f4a..b1041c874aa4 100644 --- a/front/lib/extract_event_markers.ts +++ b/front/lib/extract_event_markers.ts @@ -2,6 +2,8 @@ import { Op } from "sequelize"; import { ExtractedEvent } from "@app/lib/models"; +import { CoreAPITokenType } from "./core_api"; + const EXTRACT_EVENT_PATTERN = /\[\[(.*?)\]\]/; // Ex: [[event]] /** @@ -84,3 +86,57 @@ export async function getExtractEventMarkersToProcess({ (rawMarker) => !existingExtractedEventMarkers.includes(rawMarker) ); } + +/** + * Gets the indexes of the tokens corresponding to the marker. + * Example, for params: + * - full_text: Un petit Soupinou des bois [[idea:2]] + * - tokens: [ + [ 1844, 'Un' ], + [ 46110, ' petit' ], + [ 9424, ' Sou' ], + [ 13576, 'pin' ], + [ 283, 'ou' ], + [ 951, ' des' ], + [ 66304, ' bois' ], + [ 4416, ' [[' ], + [ 42877, 'idea' ], + [ 25, ':' ], + [ 17, '2' ], + [ 5163, ']]' ] + ] + * - marker: "[[idea:2]]" + * Will return { start: 7, end: 11 } + */ +export function findMarkersIndexes({ + fullText, + marker, + tokens, +}: { + fullText: string; + marker: string; + tokens: CoreAPITokenType[]; +}): { start: number; end: number } { + const markerIndex = fullText.indexOf(marker); + if (markerIndex === -1) return { start: -1, end: -1 }; + + let charCount = 0; + let startIndex = -1; + let endIndex = -1; + + for (let i = 0; i < tokens.length; i++) { + const str = tokens[i][1]; + charCount += str.length; + + if (startIndex === -1 && charCount > markerIndex) { + startIndex = i; + } + + if (charCount >= markerIndex + marker.length) { + endIndex = i; + break; + } + } + + return { start: startIndex, end: endIndex }; +} diff --git a/front/lib/extract_events.ts b/front/lib/extract_events.ts index 617fe94f0997..66a23cf61e4a 100644 --- a/front/lib/extract_events.ts +++ 
b/front/lib/extract_events.ts @@ -147,7 +147,7 @@ async function _processExtractEventsForMarker({ // 2/ Check that the document is not to big for the Dust App. const contentToProcess = await _getMaxTextContentToProcess({ - fullDocumentText: documentText, + fullText: documentText, marker: marker, }); diff --git a/front/tests/lib/markers.test.ts b/front/tests/lib/markers.test.ts index 09adf93e1b34..8660410060da 100644 --- a/front/tests/lib/markers.test.ts +++ b/front/tests/lib/markers.test.ts @@ -1,4 +1,6 @@ +import { CoreAPITokenType } from "@app/lib/core_api"; import { + findMarkersIndexes, getRawExtractEventMarkersFromText, hasExtractEventMarker, } from "@app/lib/extract_event_markers"; @@ -74,3 +76,80 @@ describe("Test getExtractEventMarker", function () { }); }); }); + +describe("Test findMarkerIndexes", function () { + test("findMarkerIndexes", function () { + const fullTextSoupinou = "Un petit Soupinou des bois [[idea:2]]"; + const tokensSoupinou: CoreAPITokenType[] = [ + [1844, "Un"], + [46110, " petit"], + [9424, " Sou"], + [13576, "pin"], + [283, "ou"], + [951, " des"], + [66304, " bois"], + [4416, " [["], + [42877, "idea"], + [25, ":"], + [17, "2"], + [5163, "]]"], + ]; + + const fullTextSticious = + "I’m not superstitious [[office_quote]] but I am a little stitious. 
[[office_quote]]"; + const tokensStitious: CoreAPITokenType[] = [ + [40, "I"], + [4344, "’m"], + [539, " not"], + [2307, " super"], + [3781, "stit"], + [1245, "ious"], + [4416, " [["], + [27614, "office"], + [46336, "_quote"], + [21128, "]],"], + [719, " but"], + [358, " I"], + [1097, " am"], + [264, " a"], + [2697, " little"], + [357, " st"], + [65795, "itious"], + [13, "."], + [4416, " [["], + [27614, "office"], + [46336, "_quote"], + [5163, "]]"], + ]; + + const cases = [ + { + fullText: fullTextSoupinou, + marker: "[[idea:2]]", + tokens: tokensSoupinou, + expected: { start: 7, end: 11 }, // main case + }, + { + fullText: fullTextSoupinou, + marker: "[[idea]]", + tokens: tokensSoupinou, + expected: { start: -1, end: -1 }, // not found + }, + { + fullText: fullTextSticious, + marker: "[[office_quote]]", + tokens: tokensStitious, + expected: { start: 6, end: 9 }, // takes the first one + }, + ]; + cases.forEach((c) => { + expect( + findMarkersIndexes({ + fullText: c.fullText, + marker: c.marker, + tokens: c.tokens, + }) + ).toEqual(c.expected); + }); + }); +});