feat: support batching for Titan text and image embeddings (#193)
adubovik authored Dec 5, 2024
1 parent 4fbd258 commit ec1b94a
Showing 5 changed files with 61 additions and 60 deletions.
26 changes: 23 additions & 3 deletions aidial_adapter_bedrock/dial_api/response.py
@@ -5,6 +5,8 @@
 from aidial_sdk.embeddings import Usage
 from pydantic import BaseModel
 
+from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
+
 
 class ModelObject(BaseModel):
     object: Literal["model"] = "model"
@@ -16,13 +18,31 @@ class ModelsResponse(BaseModel):
     data: List[ModelObject]
 
 
+def _encode_vector(
+    encoding_format: Literal["float", "base64"], vector: List[float]
+) -> List[float] | str:
+    return vector_to_base64(vector) if encoding_format == "base64" else vector
+
+
 def make_embeddings_response(
-    model: str, vectors: List[List[float] | str], usage: Usage
+    model: str,
+    encoding_format: Literal["float", "base64"],
+    vectors: List[List[float]],
+    prompt_tokens: int,
 ) -> EmbeddingsResponse:
+
+    embeddings = [_encode_vector(encoding_format, v) for v in vectors]
+
     data: List[Embedding] = [
         Embedding(index=index, embedding=embedding)
-        for index, embedding in enumerate(vectors)
+        for index, embedding in enumerate(embeddings)
     ]
 
-    return EmbeddingsResponse(model=model, data=data, usage=usage)
+    return EmbeddingsResponse(
+        model=model,
+        data=data,
+        usage=Usage(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )
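
The refactored make_embeddings_response now owns both usage accounting and encoding, so callers pass raw float vectors plus a prompt-token count. For reference, a minimal sketch of what vector_to_base64 (the helper imported at the top of this file) plausibly does, assuming the OpenAI-style convention of base64-encoding the vector's little-endian float32 bytes; the repository's actual helper may differ:

import base64
import struct
from typing import List


def vector_to_base64(vector: List[float]) -> str:
    # Assumed convention: pack each component as a 4-byte little-endian
    # float32, then base64-encode the resulting buffer.
    packed = struct.pack(f"<{len(vector)}f", *vector)
    return base64.b64encode(packed).decode("ascii")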
39 changes: 17 additions & 22 deletions aidial_adapter_bedrock/embedding/amazon/titan_image.py
@@ -5,11 +5,11 @@
 https://github.com/aws-samples/amazon-bedrock-samples/blob/5752afb78e7fab49cfd42d38bb09d40756bf0ea0/multimodal/Titan/titan-multimodal-embeddings/rag/1_multimodal_rag.ipynb
 """
 
-from typing import AsyncIterator, List, Self
+import asyncio
+from typing import AsyncIterator, List, Self, Tuple
 
 from aidial_sdk.chat_completion import Attachment
 from aidial_sdk.embeddings import Response as EmbeddingsResponse
-from aidial_sdk.embeddings import Usage
 from aidial_sdk.embeddings.request import EmbeddingsRequest
 from pydantic import BaseModel
 
@@ -30,7 +30,6 @@
 from aidial_adapter_bedrock.embedding.embeddings_adapter import (
     EmbeddingsAdapter,
 )
-from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
 from aidial_adapter_bedrock.embedding.validation import (
     validate_embeddings_request,
 )
@@ -155,31 +154,27 @@ async def embeddings(
             supports_dimensions=True,
         )
 
-        vectors: List[List[float] | str] = []
-        token_count = 0
-
-        # NOTE: Amazon Titan doesn't support batched inputs
-        # TODO: create multiple tasks
-        async for sub_request in get_requests(self.storage, request):
+        async def compute_embeddings(
+            req: AmazonRequest,
+        ) -> Tuple[List[float], int]:
             embedding, text_tokens = await call_embedding_model(
                 self.client,
                 self.model,
-                create_titan_request(sub_request, request.dimensions),
+                create_titan_request(req, request.dimensions),
             )
-
-            image_tokens = sub_request.get_image_tokens()
-
-            vector = (
-                vector_to_base64(embedding)
-                if request.encoding_format == "base64"
-                else embedding
-            )
-
-            vectors.append(vector)
-            token_count += text_tokens + image_tokens
+            image_tokens = req.get_image_tokens()
+            return embedding, text_tokens + image_tokens
+
+        # NOTE: Amazon Titan doesn't support batched inputs
+        tasks = [
+            asyncio.create_task(compute_embeddings(req))
+            async for req in get_requests(self.storage, request)
+        ]
+        results = await asyncio.gather(*tasks)
 
         return make_embeddings_response(
             model=self.model,
-            vectors=vectors,
-            usage=Usage(prompt_tokens=token_count, total_tokens=token_count),
+            encoding_format=request.encoding_format,
+            vectors=[r[0] for r in results],
+            prompt_tokens=sum(r[1] for r in results),
         )
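
Since Titan embedding models accept only one input per invocation, the commit keeps the per-input calls but schedules them concurrently: a task is created for each sub-request as the async iterator yields it, then asyncio.gather awaits them all, returning results in task-creation order so each vector stays aligned with its input. A self-contained toy of the same pattern (not the adapter's code; embed stands in for one model call):

import asyncio
from typing import AsyncIterator, List, Tuple


async def inputs() -> AsyncIterator[str]:
    for text in ["first", "second", "third"]:
        yield text


async def embed(text: str) -> Tuple[List[float], int]:
    await asyncio.sleep(0.1)  # stands in for one Bedrock invocation
    return [float(len(text))], len(text.split())


async def main() -> None:
    # One task per input; gather returns results in task-creation order.
    tasks = [asyncio.create_task(embed(t)) async for t in inputs()]
    results = await asyncio.gather(*tasks)
    vectors = [r[0] for r in results]
    total_tokens = sum(r[1] for r in results)
    print(vectors, total_tokens)


asyncio.run(main())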
37 changes: 16 additions & 21 deletions aidial_adapter_bedrock/embedding/amazon/titan_text.py
@@ -5,10 +5,10 @@
 https://github.com/aws-samples/amazon-bedrock-samples/blob/5752afb78e7fab49cfd42d38bb09d40756bf0ea0/multimodal/Titan/embeddings/v2/Titan-V2-Embeddings.ipynb
 """
 
-from typing import AsyncIterator, List, Self
+import asyncio
+from typing import AsyncIterator, List, Self, Tuple
 
 from aidial_sdk.embeddings import Response as EmbeddingsResponse
-from aidial_sdk.embeddings import Usage
 from aidial_sdk.embeddings.request import EmbeddingsRequest
 
 from aidial_adapter_bedrock.bedrock import Bedrock
@@ -23,7 +23,6 @@
 from aidial_adapter_bedrock.embedding.embeddings_adapter import (
     EmbeddingsAdapter,
 )
-from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
 from aidial_adapter_bedrock.embedding.validation import (
     validate_embeddings_request,
 )
@@ -74,27 +73,23 @@ async def embeddings(
             supports_dimensions=self.supports_dimensions,
         )
 
-        vectors: List[List[float] | str] = []
-        token_count = 0
-
-        # NOTE: Amazon Titan doesn't support batched inputs
-        async for text_input in get_text_inputs(request):
-            sub_request = create_titan_request(text_input, request.dimensions)
-            embedding, tokens = await call_embedding_model(
-                self.client, self.model, sub_request
+        async def compute_embeddings(req: str) -> Tuple[List[float], int]:
+            return await call_embedding_model(
+                self.client,
+                self.model,
+                create_titan_request(req, request.dimensions),
             )
 
-            vector = (
-                vector_to_base64(embedding)
-                if request.encoding_format == "base64"
-                else embedding
-            )
-
-            vectors.append(vector)
-            token_count += tokens
+        # NOTE: Amazon Titan doesn't support batched inputs
+        tasks = [
+            asyncio.create_task(compute_embeddings(req))
+            async for req in get_text_inputs(request)
+        ]
+        results = await asyncio.gather(*tasks)
 
         return make_embeddings_response(
             model=self.model,
-            vectors=vectors,
-            usage=Usage(prompt_tokens=token_count, total_tokens=token_count),
+            encoding_format=request.encoding_format,
+            vectors=[r[0] for r in results],
+            prompt_tokens=sum(r[1] for r in results),
        )
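
The same fan-out applies to the text model. Note that gather starts every sub-request at once, so a large batch opens that many concurrent Bedrock calls; if rate limits became a concern, concurrency could be bounded with a semaphore. A hypothetical sketch, not part of this commit:

import asyncio
from typing import List

semaphore = asyncio.Semaphore(8)  # assumption: cap of 8 in-flight calls


async def embed_one(text: str) -> List[float]:
    # Each task waits for a free slot before invoking the model.
    async with semaphore:
        await asyncio.sleep(0.1)  # stands in for the Bedrock call
        return [float(len(text))]


async def main() -> None:
    tasks = [asyncio.create_task(embed_one(f"input {i}")) for i in range(100)]
    vectors = await asyncio.gather(*tasks)
    print(len(vectors))


asyncio.run(main())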
16 changes: 3 additions & 13 deletions aidial_adapter_bedrock/embedding/cohere/embed_text.py
@@ -9,7 +9,6 @@
 from typing import AsyncIterator, List, Self
 
 from aidial_sdk.embeddings import Response as EmbeddingsResponse
-from aidial_sdk.embeddings import Usage
 from aidial_sdk.embeddings.request import EmbeddingsRequest
 
 from aidial_adapter_bedrock.bedrock import Bedrock
@@ -24,7 +23,6 @@
 from aidial_adapter_bedrock.embedding.embeddings_adapter import (
     EmbeddingsAdapter,
 )
-from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
 from aidial_adapter_bedrock.embedding.validation import (
     validate_embeddings_request,
 )
@@ -92,17 +90,9 @@ async def embeddings(
             self.client, self.model, embedding_request
         )
 
-        vectors: List[List[float] | str] = [
-            (
-                vector_to_base64(embedding)
-                if request.encoding_format == "base64"
-                else embedding
-            )
-            for embedding in embeddings
-        ]
-
         return make_embeddings_response(
             model=self.model,
-            vectors=vectors,
-            usage=Usage(prompt_tokens=tokens, total_tokens=tokens),
+            encoding_format=request.encoding_format,
+            vectors=embeddings,
+            prompt_tokens=tokens,
         )
3 changes: 2 additions & 1 deletion aidial_adapter_bedrock/llm/model/llama/v3.py
@@ -1,4 +1,5 @@
 import json
+from typing import Awaitable, Callable
 
 from aidial_adapter_bedrock.dial_api.request import ModelParameters
 from aidial_adapter_bedrock.llm.converse.adapter import (
@@ -26,7 +27,7 @@ def is_stream(self, params: ModelParameters) -> bool:
 
 def input_tokenizer_factory(
     deployment: ConverseDeployment, params: ConverseRequestWrapper
-):
+) -> Callable[[ConverseMessages], Awaitable[int]]:
     tool_tokens = default_tokenize_string(json.dumps(params.toolConfig))
     system_tokens = default_tokenize_string(json.dumps(params.system))
 
