From ec1b94a34985556004aa68ce1fb8bc15ecb23c55 Mon Sep 17 00:00:00 2001
From: Anton Dubovik
Date: Thu, 5 Dec 2024 16:32:26 +0000
Subject: [PATCH] feat: supported batching for Titan text and image embeddings (#193)

---
 aidial_adapter_bedrock/dial_api/response.py  | 26 +++++++++++--
 .../embedding/amazon/titan_image.py          | 39 ++++++++-----------
 .../embedding/amazon/titan_text.py           | 37 ++++++++----------
 .../embedding/cohere/embed_text.py           | 16 ++------
 aidial_adapter_bedrock/llm/model/llama/v3.py |  3 +-
 5 files changed, 61 insertions(+), 60 deletions(-)

diff --git a/aidial_adapter_bedrock/dial_api/response.py b/aidial_adapter_bedrock/dial_api/response.py
index c3b11b89..f46a8101 100644
--- a/aidial_adapter_bedrock/dial_api/response.py
+++ b/aidial_adapter_bedrock/dial_api/response.py
@@ -5,6 +5,8 @@
 from aidial_sdk.embeddings import Usage
 from pydantic import BaseModel
 
+from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
+
 
 class ModelObject(BaseModel):
     object: Literal["model"] = "model"
@@ -16,13 +18,31 @@ class ModelsResponse(BaseModel):
     data: List[ModelObject]
 
 
+def _encode_vector(
+    encoding_format: Literal["float", "base64"], vector: List[float]
+) -> List[float] | str:
+    return vector_to_base64(vector) if encoding_format == "base64" else vector
+
+
 def make_embeddings_response(
-    model: str, vectors: List[List[float] | str], usage: Usage
+    model: str,
+    encoding_format: Literal["float", "base64"],
+    vectors: List[List[float]],
+    prompt_tokens: int,
 ) -> EmbeddingsResponse:
+    embeddings = [_encode_vector(encoding_format, v) for v in vectors]
+
     data: List[Embedding] = [
         Embedding(index=index, embedding=embedding)
-        for index, embedding in enumerate(vectors)
+        for index, embedding in enumerate(embeddings)
     ]
 
-    return EmbeddingsResponse(model=model, data=data, usage=usage)
+    return EmbeddingsResponse(
+        model=model,
+        data=data,
+        usage=Usage(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )
diff --git a/aidial_adapter_bedrock/embedding/amazon/titan_image.py b/aidial_adapter_bedrock/embedding/amazon/titan_image.py
index 87a65f49..407bafba 100644
--- a/aidial_adapter_bedrock/embedding/amazon/titan_image.py
+++ b/aidial_adapter_bedrock/embedding/amazon/titan_image.py
@@ -5,11 +5,11 @@
 https://github.com/aws-samples/amazon-bedrock-samples/blob/5752afb78e7fab49cfd42d38bb09d40756bf0ea0/multimodal/Titan/titan-multimodal-embeddings/rag/1_multimodal_rag.ipynb
 """
 
-from typing import AsyncIterator, List, Self
+import asyncio
+from typing import AsyncIterator, List, Self, Tuple
 
 from aidial_sdk.chat_completion import Attachment
 from aidial_sdk.embeddings import Response as EmbeddingsResponse
-from aidial_sdk.embeddings import Usage
 from aidial_sdk.embeddings.request import EmbeddingsRequest
 from pydantic import BaseModel
 
@@ -30,7 +30,6 @@
 from aidial_adapter_bedrock.embedding.embeddings_adapter import (
     EmbeddingsAdapter,
 )
-from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
 from aidial_adapter_bedrock.embedding.validation import (
     validate_embeddings_request,
 )
@@ -155,31 +154,27 @@ async def embeddings(
             supports_dimensions=True,
         )
 
-        vectors: List[List[float] | str] = []
-        token_count = 0
-
-        # NOTE: Amazon Titan doesn't support batched inputs
-        # TODO: create multiple tasks
-        async for sub_request in get_requests(self.storage, request):
+        async def compute_embeddings(
+            req: AmazonRequest,
+        ) -> Tuple[List[float], int]:
             embedding, text_tokens = await call_embedding_model(
                 self.client,
                 self.model,
-                create_titan_request(sub_request, request.dimensions),
-            )
-
-            image_tokens = sub_request.get_image_tokens()
-
-            vector = (
-                vector_to_base64(embedding)
-                if request.encoding_format == "base64"
-                else embedding
+                create_titan_request(req, request.dimensions),
             )
+            image_tokens = req.get_image_tokens()
+            return embedding, text_tokens + image_tokens
 
-            vectors.append(vector)
-            token_count += text_tokens + image_tokens
+        # NOTE: Amazon Titan doesn't support batched inputs
+        tasks = [
+            asyncio.create_task(compute_embeddings(req))
+            async for req in get_requests(self.storage, request)
+        ]
+        results = await asyncio.gather(*tasks)
 
         return make_embeddings_response(
             model=self.model,
-            vectors=vectors,
-            usage=Usage(prompt_tokens=token_count, total_tokens=token_count),
+            encoding_format=request.encoding_format,
+            vectors=[r[0] for r in results],
+            prompt_tokens=sum(r[1] for r in results),
         )
diff --git a/aidial_adapter_bedrock/embedding/amazon/titan_text.py b/aidial_adapter_bedrock/embedding/amazon/titan_text.py
index 2117f7cd..5a2424c9 100644
--- a/aidial_adapter_bedrock/embedding/amazon/titan_text.py
+++ b/aidial_adapter_bedrock/embedding/amazon/titan_text.py
@@ -5,10 +5,10 @@
 https://github.com/aws-samples/amazon-bedrock-samples/blob/5752afb78e7fab49cfd42d38bb09d40756bf0ea0/multimodal/Titan/embeddings/v2/Titan-V2-Embeddings.ipynb
 """
 
-from typing import AsyncIterator, List, Self
+import asyncio
+from typing import AsyncIterator, List, Self, Tuple
 
 from aidial_sdk.embeddings import Response as EmbeddingsResponse
-from aidial_sdk.embeddings import Usage
 from aidial_sdk.embeddings.request import EmbeddingsRequest
 
 from aidial_adapter_bedrock.bedrock import Bedrock
@@ -23,7 +23,6 @@
 from aidial_adapter_bedrock.embedding.embeddings_adapter import (
     EmbeddingsAdapter,
 )
-from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
 from aidial_adapter_bedrock.embedding.validation import (
     validate_embeddings_request,
 )
@@ -74,27 +73,23 @@ async def embeddings(
             supports_dimensions=self.supports_dimensions,
         )
 
-        vectors: List[List[float] | str] = []
-        token_count = 0
-
-        # NOTE: Amazon Titan doesn't support batched inputs
-        async for text_input in get_text_inputs(request):
-            sub_request = create_titan_request(text_input, request.dimensions)
-            embedding, tokens = await call_embedding_model(
-                self.client, self.model, sub_request
+        async def compute_embeddings(req: str) -> Tuple[List[float], int]:
+            return await call_embedding_model(
+                self.client,
+                self.model,
+                create_titan_request(req, request.dimensions),
             )
 
-            vector = (
-                vector_to_base64(embedding)
-                if request.encoding_format == "base64"
-                else embedding
-            )
-
-            vectors.append(vector)
-            token_count += tokens
+        # NOTE: Amazon Titan doesn't support batched inputs
+        tasks = [
+            asyncio.create_task(compute_embeddings(req))
+            async for req in get_text_inputs(request)
+        ]
+        results = await asyncio.gather(*tasks)
 
         return make_embeddings_response(
             model=self.model,
-            vectors=vectors,
-            usage=Usage(prompt_tokens=token_count, total_tokens=token_count),
+            encoding_format=request.encoding_format,
+            vectors=[r[0] for r in results],
+            prompt_tokens=sum(r[1] for r in results),
         )
diff --git a/aidial_adapter_bedrock/embedding/cohere/embed_text.py b/aidial_adapter_bedrock/embedding/cohere/embed_text.py
index c80a1c97..b3db5152 100644
--- a/aidial_adapter_bedrock/embedding/cohere/embed_text.py
+++ b/aidial_adapter_bedrock/embedding/cohere/embed_text.py
@@ -9,7 +9,6 @@
 from typing import AsyncIterator, List, Self
 
 from aidial_sdk.embeddings import Response as EmbeddingsResponse
-from aidial_sdk.embeddings import Usage
 from aidial_sdk.embeddings.request import EmbeddingsRequest
 
 from aidial_adapter_bedrock.bedrock import Bedrock
@@ -24,7 +23,6 @@
 from aidial_adapter_bedrock.embedding.embeddings_adapter import (
     EmbeddingsAdapter,
 )
-from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
 from aidial_adapter_bedrock.embedding.validation import (
     validate_embeddings_request,
 )
@@ -92,17 +90,9 @@ async def embeddings(
             self.client, self.model, embedding_request
         )
 
-        vectors: List[List[float] | str] = [
-            (
-                vector_to_base64(embedding)
-                if request.encoding_format == "base64"
-                else embedding
-            )
-            for embedding in embeddings
-        ]
-
         return make_embeddings_response(
             model=self.model,
-            vectors=vectors,
-            usage=Usage(prompt_tokens=tokens, total_tokens=tokens),
+            encoding_format=request.encoding_format,
+            vectors=embeddings,
+            prompt_tokens=tokens,
         )
diff --git a/aidial_adapter_bedrock/llm/model/llama/v3.py b/aidial_adapter_bedrock/llm/model/llama/v3.py
index 182203ea..3cbe166c 100644
--- a/aidial_adapter_bedrock/llm/model/llama/v3.py
+++ b/aidial_adapter_bedrock/llm/model/llama/v3.py
@@ -1,4 +1,5 @@
 import json
+from typing import Awaitable, Callable
 
 from aidial_adapter_bedrock.dial_api.request import ModelParameters
 from aidial_adapter_bedrock.llm.converse.adapter import (
@@ -26,7 +27,7 @@ def is_stream(self, params: ModelParameters) -> bool:
 
 def input_tokenizer_factory(
     deployment: ConverseDeployment, params: ConverseRequestWrapper
-):
+) -> Callable[[ConverseMessages], Awaitable[int]]:
     tool_tokens = default_tokenize_string(json.dumps(params.toolConfig))
     system_tokens = default_tokenize_string(json.dumps(params.system))
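
Note: Amazon Titan embedding models accept only a single input per
invocation, so the patch emulates batching by spawning one asyncio task per
input and awaiting them together with asyncio.gather(), which returns
results in submission order and thus keeps each output vector aligned with
its input. A minimal, self-contained sketch of the same fan-out pattern
(embed_one below is a hypothetical stand-in for call_embedding_model, not
part of the patch):

    import asyncio
    from typing import List, Tuple

    async def embed_one(text: str) -> Tuple[List[float], int]:
        # Hypothetical stand-in for a single-input Bedrock invocation:
        # returns the embedding vector and its prompt token count.
        await asyncio.sleep(0.01)  # simulate network latency
        return [float(len(text))], len(text.split())

    async def embed_batch(texts: List[str]) -> Tuple[List[List[float]], int]:
        # One task per input; gather() preserves submission order, so the
        # i-th vector corresponds to the i-th input text.
        tasks = [asyncio.create_task(embed_one(t)) for t in texts]
        results = await asyncio.gather(*tasks)
        vectors = [vector for vector, _ in results]
        prompt_tokens = sum(tokens for _, tokens in results)
        return vectors, prompt_tokens

    print(asyncio.run(embed_batch(["hello world", "foo bar baz"])))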
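Note: the vector_to_base64 helper imported into response.py is not shown in
this diff. Assuming it follows the usual OpenAI-compatible convention for
encoding_format="base64" (base64 over the little-endian float32 bytes of the
vector), it would look roughly like this:

    import base64
    import struct
    from typing import List

    def vector_to_base64(vector: List[float]) -> str:
        # Assumed convention: pack the floats as little-endian float32,
        # then base64-encode the resulting bytes.
        raw = struct.pack(f"<{len(vector)}f", *vector)
        return base64.b64encode(raw).decode("ascii")

    print(vector_to_base64([0.25, -1.5, 3.0]))  # -> 'AACAPgAAwL8AAEBA'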