feat: supported Cohere embedding models (#132)
adubovik authored Jul 29, 2024
1 parent fc1f4ae commit 7c2a1f6
Showing 19 changed files with 243 additions and 24 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -4,6 +4,8 @@ The project implements [AI DIAL API](https://epam-rail.com/dial_api) for languag

## Supported models

### Chat completion models

The following models support the `POST SERVER_URL/openai/deployments/DEPLOYMENT_NAME/chat/completions` endpoint, along with optional support for the `/tokenize` and `/truncate_prompt` endpoints:

|Vendor|Model|Deployment name|Modality|`/tokenize`|`/truncate_prompt`|tools/functions support|precise tokenization|
@@ -30,13 +32,17 @@ The models that support `/truncate_prompt` do also support `max_prompt_tokens` r

Certain models do not support precise tokenization because their tokenization algorithm is not known. Instead, an approximate tokenization algorithm is used: it conservatively counts every byte in the UTF-8 encoding of a string as a single token.
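
A minimal sketch of that approximation (the helper name is illustrative, not part of the adapter):

```python
def approximate_token_count(text: str) -> int:
    # Conservative fallback: every byte of the UTF-8 encoding
    # of the string counts as one token.
    return len(text.encode("utf-8"))

assert approximate_token_count("hello") == 5
assert approximate_token_count("héllo") == 6  # "é" takes two bytes in UTF-8
```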

### Embedding models

The following models support the `SERVER_URL/openai/deployments/DEPLOYMENT_NAME/embeddings` endpoint (an example request follows the table):

|Model|Deployment name|Modality|
|---|---|---|
|Titan Multimodal Embeddings Generation 1 (G1)|amazon.titan-embed-image-v1|image/text-to-embedding|
|Amazon Titan Text Embeddings V2|amazon.titan-embed-text-v2:0|text-to-embedding|
|Titan Embeddings G1 – Text v1.2|amazon.titan-embed-text-v1|text-to-embedding|
|Cohere Embed English|cohere.embed-english-v3|text-to-embedding|
|Cohere Embed Multilingual|cohere.embed-multilingual-v3|text-to-embedding|
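
A hedged example request against a Cohere deployment. The host, port, and `Api-Key` header are assumptions about a local deployment, and `search_document` is one of the input types from the Cohere docs:

```python
import requests

# Assumed local deployment; adjust the URL and auth header to your setup.
response = requests.post(
    "http://localhost:5000/openai/deployments/cohere.embed-english-v3/embeddings",
    headers={"Api-Key": "YOUR_DIAL_API_KEY"},
    json={
        "input": ["Hello, world!"],
        # Cohere models require an embedding type (see the validation changes below).
        "custom_fields": {"type": "search_document"},
    },
)
print(response.json()["data"][0]["embedding"][:3])
```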

## Developer environment

15 changes: 12 additions & 3 deletions aidial_adapter_bedrock/bedrock.py
@@ -1,7 +1,7 @@
import json
from abc import ABC
from logging import DEBUG
from typing import Any, AsyncIterator, Optional
from typing import Any, AsyncIterator, Mapping, Optional, Tuple

import boto3
from botocore.eventstream import EventStream
@@ -16,6 +16,9 @@
from aidial_adapter_bedrock.utils.json import json_dumps_short
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log

Body = dict
Headers = Mapping[str, str]


class Bedrock:
client: Any
@@ -38,7 +41,9 @@ def _create_invoke_params(self, model: str, body: dict) -> dict:
"contentType": "application/json",
}

async def ainvoke_non_streaming(self, model: str, args: dict) -> dict:
async def ainvoke_non_streaming(
self, model: str, args: dict
) -> Tuple[Body, Headers]:

if log.isEnabledFor(DEBUG):
log.debug(
@@ -54,10 +59,14 @@ async def ainvoke_non_streaming(self, model: str, args: dict) -> dict:
body: StreamingBody = response["body"]
body_dict = json.loads(await make_async(lambda: body.read()))

response_headers = response.get("ResponseMetadata", {}).get(
"HTTPHeaders", {}
)

if log.isEnabledFor(DEBUG):
log.debug(f"response['body']: {json_dumps_short(body_dict)}")

return body_dict
return body_dict, response_headers

async def ainvoke_streaming(
self, model: str, args: dict
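A minimal sketch of how a caller consumes the new `(body, headers)` return value; the region, model, and request body are illustrative:

```python
import asyncio

from aidial_adapter_bedrock.bedrock import Bedrock


async def main() -> None:
    client = await Bedrock.acreate("us-east-1")  # region is illustrative
    body, headers = await client.ainvoke_non_streaming(
        "cohere.embed-english-v3",
        {"texts": ["hello"], "input_type": "search_document"},
    )
    # Bedrock reports embedding token usage in HTTP headers, not in the body.
    print(body["embeddings"][0][:3])
    print(headers.get("x-amzn-bedrock-input-token-count"))


asyncio.run(main())
```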
3 changes: 3 additions & 0 deletions aidial_adapter_bedrock/deployments.py
@@ -50,6 +50,9 @@ class EmbeddingsDeployment(str, Enum):
AMAZON_TITAN_EMBED_TEXT_V2 = "amazon.titan-embed-text-v2:0"
AMAZON_TITAN_EMBED_IMAGE_V1 = "amazon.titan-embed-image-v1"

COHERE_EMBED_ENGLISH_V3 = "cohere.embed-english-v3"
COHERE_EMBED_MULTILINGUAL_V3 = "cohere.embed-multilingual-v3"

@property
def deployment_id(self) -> str:
"""Deployment id under which the model is served by the adapter."""
Empty file.
aidial_adapter_bedrock/embedding/amazon/response.py
@@ -13,6 +13,6 @@ class AmazonResponse(BaseModel):
async def call_embedding_model(
client: Bedrock, model: str, request: dict
) -> Tuple[List[float], int]:
response_dict = await client.ainvoke_non_streaming(model, request)
response_dict, _headers = await client.ainvoke_non_streaming(model, request)
response = AmazonResponse.parse_obj(response_dict)
return response.embedding, response.inputTextTokenCount
10 changes: 8 additions & 2 deletions aidial_adapter_bedrock/embedding/amazon/titan_image.py
@@ -23,7 +23,9 @@
FileStorage,
create_file_storage,
)
from aidial_adapter_bedrock.embedding.amazon.base import call_embedding_model
from aidial_adapter_bedrock.embedding.amazon.response import (
call_embedding_model,
)
from aidial_adapter_bedrock.embedding.attachments import download_base64_data
from aidial_adapter_bedrock.embedding.embeddings_adapter import (
EmbeddingsAdapter,
@@ -134,7 +136,11 @@ async def embeddings(
self, request: EmbeddingsRequest
) -> EmbeddingsResponse:

validate_embeddings_request(request, supports_dimensions=True)
validate_embeddings_request(
request,
supports_type=False,
supports_dimensions=True,
)

vectors: List[List[float] | str] = []
token_count = 0
8 changes: 6 additions & 2 deletions aidial_adapter_bedrock/embedding/amazon/titan_text.py
@@ -17,7 +17,9 @@
collect_embedding_inputs_without_attachments,
)
from aidial_adapter_bedrock.dial_api.response import make_embeddings_response
from aidial_adapter_bedrock.embedding.amazon.base import call_embedding_model
from aidial_adapter_bedrock.embedding.amazon.response import (
call_embedding_model,
)
from aidial_adapter_bedrock.embedding.embeddings_adapter import (
EmbeddingsAdapter,
)
@@ -67,7 +69,9 @@ async def embeddings(
) -> EmbeddingsResponse:

validate_embeddings_request(
request, supports_dimensions=self.supports_dimensions
request,
supports_type=False,
supports_dimensions=self.supports_dimensions,
)

vectors: List[List[float] | str] = []
Empty file.
108 changes: 108 additions & 0 deletions aidial_adapter_bedrock/embedding/cohere/embed_text.py
@@ -0,0 +1,108 @@
"""
Text Embeddings Adapter for Cohere Embed model
See the documentation:
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
https://docs.cohere.com/reference/embed
"""

from typing import AsyncIterator, List, Self

from aidial_sdk.embeddings import Response as EmbeddingsResponse
from aidial_sdk.embeddings import Usage
from aidial_sdk.embeddings.request import EmbeddingsRequest

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.embedding_inputs import (
EMPTY_INPUT_LIST_ERROR,
collect_embedding_inputs_without_attachments,
)
from aidial_adapter_bedrock.dial_api.response import make_embeddings_response
from aidial_adapter_bedrock.embedding.cohere.response import (
call_embedding_model,
)
from aidial_adapter_bedrock.embedding.embeddings_adapter import (
EmbeddingsAdapter,
)
from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
from aidial_adapter_bedrock.embedding.validation import (
validate_embeddings_request,
)
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.utils.json import remove_nones


def create_cohere_request(texts: List[str], input_type: str) -> dict:
return remove_nones(
{
"texts": texts,
"input_type": input_type,
}
)


def get_text_inputs(request: EmbeddingsRequest) -> AsyncIterator[str]:
async def on_texts(texts: List[str]) -> str:
if len(texts) == 0:
raise EMPTY_INPUT_LIST_ERROR
elif len(texts) == 1:
return texts[0]
else:
raise ValidationError(
"No more than one element is allowed in an element of custom_input list"
)

return collect_embedding_inputs_without_attachments(
request, on_texts=on_texts
)


class CohereTextEmbeddings(EmbeddingsAdapter):
model: str
client: Bedrock

@classmethod
def create(cls, client: Bedrock, model: str) -> Self:
return cls(client=client, model=model)

async def embeddings(
self, request: EmbeddingsRequest
) -> EmbeddingsResponse:

validate_embeddings_request(
request,
supports_type=True,
supports_dimensions=False,
)

input_type: str | None = (
request.custom_fields and request.custom_fields.type
)

if input_type is None:
raise ValidationError(
"Embedding type request parameter is required"
)

text_inputs = [txt async for txt in get_text_inputs(request)]

embedding_request = create_cohere_request(text_inputs, input_type)

embeddings, tokens = await call_embedding_model(
self.client, self.model, embedding_request
)

vectors: List[List[float] | str] = [
(
vector_to_base64(embedding)
if request.encoding_format == "base64"
else embedding
)
for embedding in embeddings
]

return make_embeddings_response(
model=self.model,
vectors=vectors,
usage=Usage(prompt_tokens=tokens, total_tokens=tokens),
)
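For reference, a quick check of the request body that `create_cohere_request` builds, grounded in the code above (`remove_nones` simply drops keys whose value is `None`):

```python
from aidial_adapter_bedrock.embedding.cohere.embed_text import create_cohere_request

assert create_cohere_request(["hello", "world"], "search_query") == {
    "texts": ["hello", "world"],
    "input_type": "search_query",
}
```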
29 changes: 29 additions & 0 deletions aidial_adapter_bedrock/embedding/cohere/response.py
@@ -0,0 +1,29 @@
from typing import List, Literal, Tuple

from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class CohereResponse(BaseModel):
id: str
response_type: Literal["embeddings_floats"]
embeddings: List[List[float]]
texts: List[str]
# According to https://docs.cohere.com/reference/embed
# input tokens are expected to be returned in the response field `meta`.
# However, Bedrock moved it to the response headers.


async def call_embedding_model(
client: Bedrock, model: str, request: dict
) -> Tuple[List[List[float]], int]:
body, headers = await client.ainvoke_non_streaming(model, request)
response = CohereResponse.parse_obj(body)

input_tokens = int(headers.get("x-amzn-bedrock-input-token-count", "0"))
if input_tokens == 0:
log.warning("Can't extract input tokens from embeddings response")

return response.embeddings, input_tokens
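
A sketch of the payload shape this parser expects, with illustrative values matching the `CohereResponse` model above:

```python
from aidial_adapter_bedrock.embedding.cohere.response import CohereResponse

body = {
    "id": "resp-123",  # illustrative
    "response_type": "embeddings_floats",
    "embeddings": [[0.1, -0.2, 0.3]],
    "texts": ["hello"],
}
parsed = CohereResponse.parse_obj(body)
assert parsed.embeddings == [[0.1, -0.2, 0.3]]

# Token usage arrives in the HTTP headers rather than the body:
headers = {"x-amzn-bedrock-input-token-count": "1"}  # illustrative
assert int(headers.get("x-amzn-bedrock-input-token-count", "0")) == 1
```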
7 changes: 5 additions & 2 deletions aidial_adapter_bedrock/embedding/validation.py
@@ -4,7 +4,10 @@


def validate_embeddings_request(
request: EmbeddingsRequest, *, supports_dimensions: bool
request: EmbeddingsRequest,
*,
supports_type: bool,
supports_dimensions: bool
) -> None:
if request.dimensions is not None and not supports_dimensions:
raise ValidationError("Dimensions parameter is not supported")
@@ -13,7 +16,7 @@ def validate_embeddings_request(
if request.custom_fields.instruction is not None:
raise ValidationError("Instruction prompt is not supported")

if request.custom_fields.type is not None:
if request.custom_fields.type is not None and not supports_type:
raise ValidationError(
"The embedding model does not support embedding types"
)
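How the adapters in this commit set the two validation flags, collected from the diffs above:

```python
# supports_type / supports_dimensions per adapter in this commit:
#   AmazonTitanImageEmbeddings: supports_type=False, supports_dimensions=True
#   AmazonTitanTextEmbeddings:  supports_type=False, supports_dimensions=True (v2), False (v1)
#   CohereTextEmbeddings:       supports_type=True,  supports_dimensions=False
```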
17 changes: 12 additions & 5 deletions aidial_adapter_bedrock/llm/model/adapter.py
@@ -11,6 +11,9 @@
from aidial_adapter_bedrock.embedding.amazon.titan_text import (
AmazonTitanTextEmbeddings,
)
from aidial_adapter_bedrock.embedding.cohere.embed_text import (
CohereTextEmbeddings,
)
from aidial_adapter_bedrock.embedding.embeddings_adapter import (
EmbeddingsAdapter,
)
@@ -91,18 +94,22 @@ async def get_embeddings_model(
deployment: EmbeddingsDeployment, region: str, api_key: str
) -> EmbeddingsAdapter:
model = deployment.model_id
client = await Bedrock.acreate(region)
match deployment:
case EmbeddingsDeployment.AMAZON_TITAN_EMBED_TEXT_V1:
return AmazonTitanTextEmbeddings.create(
await Bedrock.acreate(region), model, supports_dimensions=False
client, model, supports_dimensions=False
)
case EmbeddingsDeployment.AMAZON_TITAN_EMBED_TEXT_V2:
return AmazonTitanTextEmbeddings.create(
await Bedrock.acreate(region), model, supports_dimensions=True
client, model, supports_dimensions=True
)
case EmbeddingsDeployment.AMAZON_TITAN_EMBED_IMAGE_V1:
return AmazonTitanImageEmbeddings.create(
await Bedrock.acreate(region), model, api_key
)
return AmazonTitanImageEmbeddings.create(client, model, api_key)
case (
EmbeddingsDeployment.COHERE_EMBED_ENGLISH_V3
| EmbeddingsDeployment.COHERE_EMBED_MULTILINGUAL_V3
):
return CohereTextEmbeddings.create(client, model)
case _:
assert_never(deployment)
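A hedged usage sketch of the dispatch above; the region and API key are placeholders, and per this hunk `api_key` is only consumed by the Titan image adapter:

```python
import asyncio

from aidial_adapter_bedrock.deployments import EmbeddingsDeployment
from aidial_adapter_bedrock.llm.model.adapter import get_embeddings_model


async def main() -> None:
    adapter = await get_embeddings_model(
        deployment=EmbeddingsDeployment.COHERE_EMBED_ENGLISH_V3,
        region="us-east-1",  # placeholder
        api_key="DIAL_API_KEY",  # placeholder
    )
    # `adapter` is a CohereTextEmbeddings instance serving /embeddings requests.


asyncio.run(main())
```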
4 changes: 3 additions & 1 deletion aidial_adapter_bedrock/llm/model/ai21.py
@@ -123,7 +123,9 @@ async def predict(
self, consumer: Consumer, params: ModelParameters, prompt: str
):
args = create_request(prompt, convert_params(params))
response = await self.client.ainvoke_non_streaming(self.model, args)
response, _headers = await self.client.ainvoke_non_streaming(
self.model, args
)

resp = AI21Response.parse_obj(response)

4 changes: 3 additions & 1 deletion aidial_adapter_bedrock/llm/model/amazon.py
@@ -138,7 +138,9 @@ async def predict(
chunks = self.client.ainvoke_streaming(self.model, args)
stream = chunks_to_stream(chunks, usage)
else:
response = await self.client.ainvoke_non_streaming(self.model, args)
response, _headers = await self.client.ainvoke_non_streaming(
self.model, args
)
stream = response_to_stream(response, usage)

stream = self.post_process_stream(stream, params, self.chat_emulator)
4 changes: 3 additions & 1 deletion aidial_adapter_bedrock/llm/model/claude/v1_v2/adapter.py
@@ -130,7 +130,9 @@ async def predict(
chunks = self.client.ainvoke_streaming(self.model, args)
stream = chunks_to_stream(chunks)
else:
response = await self.client.ainvoke_non_streaming(self.model, args)
response, _headers = await self.client.ainvoke_non_streaming(
self.model, args
)
stream = response_to_stream(response)

stream = stream_utils.lstrip(stream)
4 changes: 3 additions & 1 deletion aidial_adapter_bedrock/llm/model/cohere.py
@@ -174,7 +174,9 @@ async def predict(
chunks = self.client.ainvoke_streaming(self.model, args)
stream = chunks_to_stream(chunks, usage)
else:
response = await self.client.ainvoke_non_streaming(self.model, args)
response, _headers = await self.client.ainvoke_non_streaming(
self.model, args
)
stream = response_to_stream(response, usage)

stream = self.post_process_stream(stream, params, self.chat_emulator)
4 changes: 3 additions & 1 deletion aidial_adapter_bedrock/llm/model/meta.py
@@ -112,7 +112,9 @@ async def predict(
chunks = self.client.ainvoke_streaming(self.model, args)
stream = chunks_to_stream(chunks, usage)
else:
response = await self.client.ainvoke_non_streaming(self.model, args)
response, _headers = await self.client.ainvoke_non_streaming(
self.model, args
)
stream = response_to_stream(response, usage)

stream = self.post_process_stream(stream, params, self.chat_emulator)