Skip to content

Commit

Permalink
feat: supported Amazon Titan Text Embeddings V2 (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Jun 14, 2024
1 parent 6b4e437 commit fe3a024
Show file tree
Hide file tree
Showing 20 changed files with 428 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ DIAL_URL=<dial core url>

# Misc env vars for the server
LOG_LEVEL=INFO # Default in prod is INFO. Use DEBUG for dev.
WEB_CONCURRENCY=1 # Number of unicorn workers
WEB_CONCURRENCY=1 # Number of uvicorn workers
TEST_SERVER_URL=http://0.0.0.0:5001
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ The models that support `/truncate_prompt` do also support `max_prompt_tokens` r

Certain models do not support precise tokenization, because the tokenization algorithm is not known. Instead, an approximate tokenization algorithm is used. It conservatively counts every byte in the UTF-8 encoding of a string as a single token.

The following models support `SERVER_URL/openai/deployments/DEPLOYMENT_NAME/embeddings` endpoint:

|Model|Deployment name|Modality|
|---|---|---|
|Amazon Titan Text Embeddings V2|amazon.titan-embed-text-v2:0|text-to-embedding|

## Developer environment

This project uses [Python>=3.11](https://www.python.org/downloads/) and [Poetry>=1.6.1](https://python-poetry.org/) as a dependency manager.
Expand Down
50 changes: 46 additions & 4 deletions aidial_adapter_bedrock/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
import json
from typing import Optional

from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig
from fastapi import Body, Header, Path

from aidial_adapter_bedrock.chat_completion import BedrockChatCompletion
from aidial_adapter_bedrock.deployments import BedrockDeployment
from aidial_adapter_bedrock.dial_api.response import ModelObject, ModelsResponse
from aidial_adapter_bedrock.deployments import (
ChatCompletionDeployment,
EmbeddingsDeployment,
)
from aidial_adapter_bedrock.dial_api.request import (
EmbeddingsRequest,
EmbeddingsType,
)
from aidial_adapter_bedrock.dial_api.response import (
ModelObject,
ModelsResponse,
make_embeddings_response,
)
from aidial_adapter_bedrock.llm.model.adapter import get_embeddings_model
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.utils.env import get_aws_default_region
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.utils.log_config import configure_loggers

AWS_DEFAULT_REGION = get_aws_default_region()
Expand All @@ -28,13 +45,38 @@ async def models():
return ModelsResponse(
data=[
ModelObject(id=deployment.deployment_id)
for deployment in BedrockDeployment
for deployment in ChatCompletionDeployment
]
)


for deployment in BedrockDeployment:
for deployment in ChatCompletionDeployment:
app.add_chat_completion(
deployment.deployment_id,
BedrockChatCompletion(region=AWS_DEFAULT_REGION),
)


@app.post("/openai/deployments/{deployment}/embeddings")
@dial_exception_decorator
async def embeddings(
    embeddings_type: EmbeddingsType = Header(
        alias="X-DIAL-Type", default=EmbeddingsType.SYMMETRIC
    ),
    embeddings_instruction: Optional[str] = Header(
        alias="X-DIAL-Instruction", default=None
    ),
    deployment: EmbeddingsDeployment = Path(...),
    request: dict = Body(..., examples=[EmbeddingsRequest.example()]),
):
    """Serve OpenAI-compatible embeddings for a Bedrock embeddings deployment.

    The embeddings type and instruction are taken from the DIAL-specific
    headers; the request body is passed through to the model adapter as-is.
    """
    log.debug(f"request: {json.dumps(request)}")

    adapter = await get_embeddings_model(
        deployment=deployment, region=AWS_DEFAULT_REGION
    )

    result = await adapter.embeddings(
        request, embeddings_instruction, embeddings_type
    )

    return make_embeddings_response(deployment, result)
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def ainvoke_non_streaming(self, model: str, args: dict) -> dict:
body: StreamingBody = response["body"]
body_dict = json.loads(await make_async(lambda: body.read()))

log.debug(f"response['body']: {body_dict}")
log.debug(f"response['body']: {json.dumps(body_dict)}")

return body_dict

Expand All @@ -65,7 +65,7 @@ async def ainvoke_streaming(
chunk = event.get("chunk")
if chunk:
chunk_dict = json.loads(chunk.get("bytes").decode())
log.debug(f"chunk: {chunk_dict}")
log.debug(f"chunk: {json.dumps(chunk_dict)}")
yield chunk_dict


Expand Down
6 changes: 4 additions & 2 deletions aidial_adapter_bedrock/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from typing_extensions import override

from aidial_adapter_bedrock.deployments import BedrockDeployment
from aidial_adapter_bedrock.deployments import ChatCompletionDeployment
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
Expand All @@ -44,7 +44,9 @@ def __init__(self, region: str):
async def _get_model(
self, request: FromRequestDeploymentMixin
) -> ChatCompletionAdapter:
deployment = BedrockDeployment.from_deployment_id(request.deployment_id)
deployment = ChatCompletionDeployment.from_deployment_id(
request.deployment_id
)
return await get_bedrock_adapter(
region=self.region,
deployment=deployment,
Expand Down
26 changes: 22 additions & 4 deletions aidial_adapter_bedrock/deployments.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum


class BedrockDeployment(str, Enum):
class ChatCompletionDeployment(str, Enum):
AMAZON_TITAN_TG1_LARGE = "amazon.titan-tg1-large"
AI21_J2_GRANDE_INSTRUCT = "ai21.j2-grande-instruct"
AI21_J2_JUMBO_INSTRUCT = "ai21.j2-jumbo-instruct"
Expand Down Expand Up @@ -30,11 +30,29 @@ def model_id(self) -> str:
"""Id of the model in the Bedrock service."""

# Redirect Stability model without version to the earliest non-deprecated version (V1)
if self == BedrockDeployment.STABILITY_STABLE_DIFFUSION_XL:
return BedrockDeployment.STABILITY_STABLE_DIFFUSION_XL_V1.model_id
if self == ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL:
return (
ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL_V1.model_id
)

return self.value

@classmethod
def from_deployment_id(cls, deployment_id: str) -> "BedrockDeployment":
def from_deployment_id(
cls, deployment_id: str
) -> "ChatCompletionDeployment":
return cls(deployment_id)


class EmbeddingsDeployment(str, Enum):
    """Embedding-model deployments served by the adapter.

    For embedding models the adapter-facing deployment id and the
    Bedrock-facing model id coincide, so both properties return the
    enum value unchanged.
    """

    AMAZON_TITAN_EMBED_TEXT_2 = "amazon.titan-embed-text-v2:0"

    @property
    def model_id(self) -> str:
        """Id of the model in the Bedrock service."""
        return self.value

    @property
    def deployment_id(self) -> str:
        """Deployment id under which the model is served by the adapter."""
        return self.value
21 changes: 20 additions & 1 deletion aidial_adapter_bedrock/dial_api/request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional
from enum import Enum
from typing import List, Literal, Optional

from aidial_sdk.chat_completion.request import ChatCompletionRequest
from pydantic import BaseModel
Expand All @@ -7,6 +8,7 @@
ToolsConfig,
validate_messages,
)
from aidial_adapter_bedrock.utils.pydantic import ExtraAllowModel


class ModelParameters(BaseModel):
Expand Down Expand Up @@ -44,3 +46,20 @@ def create(cls, request: ChatCompletionRequest) -> "ModelParameters":

def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
return self.copy(update={"stop": [*self.stop, *stop]})


class EmbeddingsType(str, Enum):
    """Kind of embeddings requested via the ``X-DIAL-Type`` header."""

    SYMMETRIC = "symmetric"
    DOCUMENT = "document"
    QUERY = "query"


class EmbeddingsRequest(ExtraAllowModel):
    """OpenAI-compatible embeddings request body.

    Extends ExtraAllowModel, so unknown fields are presumably retained and
    exposed via ``get_extra_fields()`` for model-specific pass-through —
    confirm against the ExtraAllowModel definition.
    """

    # A single text or a batch of texts to embed.
    input: str | List[str]
    user: Optional[str] = None
    encoding_format: Literal["float", "base64"] = "float"
    # Requested output dimensionality; None means the model default.
    dimensions: Optional[int] = None

    @staticmethod
    def example() -> "EmbeddingsRequest":
        """Sample request used for the endpoint's OpenAPI examples."""
        return EmbeddingsRequest(input=["fish", "ball"])
43 changes: 42 additions & 1 deletion aidial_adapter_bedrock/dial_api/response.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import List, Literal
from typing import List, Literal, Tuple, TypedDict

from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage


class ModelObject(BaseModel):
object: Literal["model"] = "model"
Expand All @@ -11,3 +13,42 @@ class ModelObject(BaseModel):
class ModelsResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ModelObject]


class EmbeddingsDict(TypedDict):
    """A single embedding entry of the OpenAI-style embeddings response."""

    index: int
    object: Literal["embedding"]
    embedding: List[float]


class EmbeddingsTokenUsageDict(TypedDict):
    """Token accounting section of the embeddings response."""

    prompt_tokens: int
    total_tokens: int


class EmbeddingsResponseDict(TypedDict):
    """Top-level OpenAI-style embeddings response payload."""

    object: Literal["list"]
    model: str
    data: List[EmbeddingsDict]
    usage: EmbeddingsTokenUsageDict


def make_embeddings_response(
    model_id: str, resp: Tuple[List[List[float]], TokenUsage]
) -> EmbeddingsResponseDict:
    """Assemble an OpenAI-style embeddings response.

    ``resp`` carries one embedding vector per input plus the accumulated
    token usage; vectors are numbered in their original order.
    """
    vectors, token_usage = resp

    items: List[EmbeddingsDict] = [
        {"index": position, "object": "embedding", "embedding": vector}
        for position, vector in enumerate(vectors)
    ]

    usage_dict: EmbeddingsTokenUsageDict = {
        "prompt_tokens": token_usage.prompt_tokens,
        "total_tokens": token_usage.total_tokens,
    }

    return {
        "object": "list",
        "model": model_id,
        "data": items,
        "usage": usage_dict,
    }
103 changes: 103 additions & 0 deletions aidial_adapter_bedrock/embeddings/amazon_titan_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Iterable, List, Literal, Optional, Self, Tuple

from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import (
EmbeddingsRequest,
EmbeddingsType,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.embeddings.embeddings_adapter import (
EmbeddingsAdapter,
)
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


def validate_parameters(
    encoding_format: Literal["float", "base64"],
    embedding_type: EmbeddingsType,
    embedding_instruction: Optional[str],
    supported_embedding_types: List[EmbeddingsType],
) -> None:
    """Reject request options the embedding model cannot honor.

    Raises ValidationError for an unsupported encoding format, for an
    instruction prompt, or for an embedding type outside the supported set.
    """
    if encoding_format == "base64":
        raise ValidationError("Base64 encoding format is not supported")

    if embedding_instruction is not None:
        raise ValidationError("Instruction prompt is not supported")

    # Internal invariant: the adapter must declare at least one supported type.
    assert (
        len(supported_embedding_types) != 0
    ), "The embedding model doesn't support any embedding types"

    if embedding_type not in supported_embedding_types:
        allowed = ", ".join(item.value for item in supported_embedding_types)
        raise ValidationError(
            f"Embedding types other than {allowed} are not supported"
        )


def create_requests(request: EmbeddingsRequest) -> Iterable[dict]:
    """Yield one Titan request body per input text.

    Amazon Titan doesn't support batched inputs, so a separate request
    body is produced for every input string.
    """
    texts: List[str] = (
        [request.input] if isinstance(request.input, str) else request.input
    )

    # This includes all Titan-specific request parameters missing
    # from the OpenAI Embeddings request, e.g. "normalize" boolean flag.
    extra_body = request.get_extra_fields()

    # NOTE: Amazon Titan doesn't support batched inputs.
    # Renamed the loop variable from `input` to avoid shadowing the builtin.
    for text in texts:
        # remove_nones presumably drops keys whose value is None
        # (e.g. an unset "dimensions") — confirm against its definition.
        yield remove_nones(
            {
                "inputText": text,
                "dimensions": request.dimensions,
                **extra_body,
            }
        )


class AmazonResponse(BaseModel):
    """Subset of the Amazon Titan embeddings response body that is consumed."""

    # Number of tokens in the input text, as reported by the model.
    inputTextTokenCount: int
    # The embedding vector for the single input text.
    embedding: List[float]


class AmazonTitanTextEmbeddings(EmbeddingsAdapter):
    """Embeddings adapter for Amazon Titan text embedding models."""

    model: str
    client: Bedrock

    @classmethod
    def create(cls, client: Bedrock, model: str) -> Self:
        """Build an adapter bound to the given Bedrock client and model id."""
        return cls(client=client, model=model)

    async def embeddings(
        self,
        request_body: dict,
        embedding_instruction: Optional[str],
        embedding_type: EmbeddingsType,
    ) -> Tuple[List[List[float]], TokenUsage]:
        """Compute one embedding vector per input string in the request.

        Returns the list of vectors (in input order) and the accumulated
        token usage.
        """
        request = EmbeddingsRequest.parse_obj(request_body)

        # Titan only supports symmetric embeddings; base64 encoding and
        # instruction prompts are rejected as well.
        validate_parameters(
            request.encoding_format,
            embedding_type,
            embedding_instruction,
            [EmbeddingsType.SYMMETRIC],
        )

        embeddings: List[List[float]] = []
        usage = TokenUsage()

        # Renamed loop variable from `request` to avoid rebinding the
        # parsed EmbeddingsRequest with per-input payload dicts.
        for sub_request in create_requests(request):
            log.debug(f"request: {sub_request}")

            response_dict = await self.client.ainvoke_non_streaming(
                self.model, sub_request
            )
            response = AmazonResponse.parse_obj(response_dict)
            embeddings.append(response.embedding)
            # Only prompt tokens are consumed here; total_tokens is
            # presumably derived inside TokenUsage — confirm.
            usage.prompt_tokens += response.inputTextTokenCount

        return embeddings, usage
21 changes: 21 additions & 0 deletions aidial_adapter_bedrock/embeddings/embeddings_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.request import EmbeddingsType
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage


class EmbeddingsAdapter(ABC, BaseModel):
    """Abstract base for model-specific embeddings adapters."""

    class Config:
        # Allow non-pydantic field types on subclasses (e.g. a Bedrock client).
        arbitrary_types_allowed = True

    @abstractmethod
    async def embeddings(
        self,
        request_body: dict,
        embedding_instruction: Optional[str],
        embedding_type: EmbeddingsType,
    ) -> Tuple[List[List[float]], TokenUsage]:
        """Compute embedding vectors for the request and report token usage."""
        pass
Loading

0 comments on commit fe3a024

Please sign in to comment.