From 98a931bf317a8de91baef48fe632bbab81e6b813 Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Fri, 10 Nov 2023 15:47:23 +0000 Subject: [PATCH 1/2] feat: migrated latest fixes * Added integration tests for max_tokens and stop sequence * Use number of bytes as token count estimator for AI21 and AWS Titan * Allow empty messages in each language model * Supported history truncation via max_prompt_tokens/discarded_messages parameters * Bumped version of aidial SDK * Removed 'Assistant' prefix occasionally generated by Titan * Supported streaming for Titan and Claude --- .vscode/settings.json | 4 +- MODEL_CARD.md | 129 ------- Makefile | 9 +- README.md | 12 +- aidial_adapter_bedrock/app.py | 32 +- aidial_adapter_bedrock/chat_completion.py | 61 +-- aidial_adapter_bedrock/llm/bedrock_adapter.py | 363 ------------------ .../llm/chat_emulation/claude.py | 55 --- .../llm/chat_emulation/claude_chat.py | 90 +++++ .../llm/chat_emulation/history.py | 31 ++ .../llm/chat_emulation/meta_chat.py | 55 --- .../llm/chat_emulation/pseudo_chat.py | 136 +++++++ .../llm/chat_emulation/types.py | 6 - .../llm/chat_emulation/zero_memory.py | 10 - .../llm/chat_emulation/zero_memory_chat.py | 33 ++ aidial_adapter_bedrock/llm/chat_model.py | 138 ++++--- aidial_adapter_bedrock/llm/consumer.py | 82 ++++ aidial_adapter_bedrock/llm/model/__init__.py | 0 aidial_adapter_bedrock/llm/model/adapter.py | 40 ++ aidial_adapter_bedrock/llm/model/ai21.py | 136 +++++++ aidial_adapter_bedrock/llm/model/amazon.py | 196 ++++++++++ aidial_adapter_bedrock/llm/model/anthropic.py | 140 +++++++ aidial_adapter_bedrock/llm/model/conf.py | 3 + aidial_adapter_bedrock/llm/model/stability.py | 123 ++++++ aidial_adapter_bedrock/llm/model_listing.py | 14 + .../universal_api/request.py | 17 + .../universal_api/token_usage.py | 10 +- aidial_adapter_bedrock/utils/list.py | 9 + aidial_adapter_bedrock/utils/log_config.py | 6 +- aidial_adapter_bedrock/utils/operators.py | 9 - aidial_adapter_bedrock/utils/stream.py | 78 ++++ aidial_adapter_bedrock/utils/text.py | 22 -- client/client_bedrock.py | 22 +- client/utils/cli.py | 14 +- client/utils/input.py | 2 +- poetry.lock | 168 +------- pyproject.toml | 3 +- scripts/find_token_limits.py | 58 --- .../integration_tests/test_chat_completion.py | 85 +++- tests/unit_tests/chat_emulation/__init__.py | 0 .../test_claude_chat_history.py | 123 ++++++ .../test_pseudo_chat_history.py | 223 +++++++++++ .../test_zero_memory_chat_history.py | 33 ++ tests/unit_tests/test_stream.py | 143 +++++++ tests/utils/llm.py | 20 +- tests/utils/string.py | 19 + 46 files changed, 1956 insertions(+), 1006 deletions(-) delete mode 100644 MODEL_CARD.md delete mode 100644 aidial_adapter_bedrock/llm/bedrock_adapter.py delete mode 100644 aidial_adapter_bedrock/llm/chat_emulation/claude.py create mode 100644 aidial_adapter_bedrock/llm/chat_emulation/claude_chat.py create mode 100644 aidial_adapter_bedrock/llm/chat_emulation/history.py delete mode 100644 aidial_adapter_bedrock/llm/chat_emulation/meta_chat.py create mode 100644 aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py delete mode 100644 aidial_adapter_bedrock/llm/chat_emulation/types.py delete mode 100644 aidial_adapter_bedrock/llm/chat_emulation/zero_memory.py create mode 100644 aidial_adapter_bedrock/llm/chat_emulation/zero_memory_chat.py create mode 100644 aidial_adapter_bedrock/llm/consumer.py create mode 100644 aidial_adapter_bedrock/llm/model/__init__.py create mode 100644 aidial_adapter_bedrock/llm/model/adapter.py create mode 100644 
aidial_adapter_bedrock/llm/model/ai21.py create mode 100644 aidial_adapter_bedrock/llm/model/amazon.py create mode 100644 aidial_adapter_bedrock/llm/model/anthropic.py create mode 100644 aidial_adapter_bedrock/llm/model/conf.py create mode 100644 aidial_adapter_bedrock/llm/model/stability.py create mode 100644 aidial_adapter_bedrock/llm/model_listing.py create mode 100644 aidial_adapter_bedrock/utils/list.py delete mode 100644 aidial_adapter_bedrock/utils/operators.py create mode 100644 aidial_adapter_bedrock/utils/stream.py delete mode 100644 aidial_adapter_bedrock/utils/text.py delete mode 100755 scripts/find_token_limits.py create mode 100644 tests/unit_tests/chat_emulation/__init__.py create mode 100644 tests/unit_tests/chat_emulation/test_claude_chat_history.py create mode 100644 tests/unit_tests/chat_emulation/test_pseudo_chat_history.py create mode 100644 tests/unit_tests/chat_emulation/test_zero_memory_chat_history.py create mode 100644 tests/unit_tests/test_stream.py create mode 100644 tests/utils/string.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 7f02da0f..65349858 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,7 +7,9 @@ }, "editor.tabSize": 4 }, - "python.testing.pytestArgs": ["."], + "python.testing.pytestArgs": [ + "." + ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, "python.analysis.typeCheckingMode": "basic" diff --git a/MODEL_CARD.md b/MODEL_CARD.md deleted file mode 100644 index 76524332..00000000 --- a/MODEL_CARD.md +++ /dev/null @@ -1,129 +0,0 @@ -# Bedrock models - -## Amazon Titan - -### Tokenization - -Returns number of used tokens in response: - -```json -{ - "inputTextTokenCount": 10, - "results": [ - { - "content": "foo", - "tokenCount": 20 - } - ] -} -``` - -```python -return TokenUsage( - prompt_tokens=resp.inputTextTokenCount, - completion_tokens=resp.results[0].tokenCount, -) -``` - -Tokenizer is unknown. - -### Token limits - -Tokens limits are unknown. - -Experimentally found tokens limits (see `./test/find_token_limits.py`): - -- amazon.titan-tg1-large: 4096 - -## Anthropic Claude - -### Tokenization - -`anthropic` package provides methods to [calculate tokens](https://github.com/anthropics/anthropic-sdk-python/blob/main/examples/tokens.py). - -```python -from anthropic.tokenizer import count_tokens - -return TokenUsage( - prompt_tokens=count_tokens(query.prompt), - completion_tokens=count_tokens(resp.completion), -) -``` - -The tokenizer could be migrated to the frontend. - -### Streaming - -The streaming is supported by [Anthropic SDK](https://github.com/anthropics/anthropic-sdk-python/blob/main/examples/streaming.py), but it's not supported by the Bedrock API. - -### Token limits - -As per [documentation](https://docs.anthropic.com/claude/docs/introduction-to-prompt-design#prompt-length) the limits aren't strictly specified: - -``` -The maximum prompt length that Claude can see is its context window. Claude's context window is currently ~75,000 words / ~100,000 tokens / ~340,000 Unicode characters. - -Right now when this context window is exceeded in the API Claude is likely to return an incoherent response. We apologize for this “sharp edge”. -``` - -However [pricing page](https://www.anthropic.com/pricing) explicitly says that context windows is 100k tokens. 
- -So: - -- max input tokens: 100k -- max outputs tokens: unknown - -However, it doesn't match with experimentally found tokens limits (see `./test/find_token_limits.py`): - -- anthropic.claude-v1: 12288 -- anthropic.claude-instant-v1: 12288 - -REST: [completion call](https://docs.anthropic.com/claude/reference/complete_post) - -## AI21 models - -### Rate limits - -See [rate limits](https://docs.ai21.com/docs/rate-limits). - -### Token limits - -REST (and Bedrock presumably): max context window = [8191 tokens](https://docs.ai21.com/reference/j2-complete-ref) - -AWS: see [token limits](https://docs.ai21.com/docs/choosing-the-right-instance-type-for-amazon-sagemaker-models#foundation-models) for various model instances - -``` -The context window acts as a threshold for the amount of tokens in the prompt and the completion, namely: prompt + completion <= context window. -``` - -Experimentally found tokens limits confirm the documentation (see `./test/find_token_limits.py`): - -- ai21.j2-grande-instruct: 8191 -- ai21.j2-jumbo-instruct: 8191 - -### Tokenization - -Response contains tokens explicitly as arrays. One needs only compute the len of the arrays: - -```python -return TokenUsage( - prompt_tokens=len(resp.prompt.tokens), - completion_tokens=len(resp.completions[0].data.tokens), -) -``` - -Tokenizer is [unknown](https://docs.ai21.com/docs/tokenizer-tokenization): - -> AI21 Studio uses a large token dictionary (250K) - -Token counting is only possible using AI21 API Key. - -**Tokenization via Bedrock API is currently unsupported** (see `./local_client_ai21.py`). - -SDK: `ai21` package calls `tokenize` endpoint to do tokenization (see `ai21/modules/tokenization.py`). - -REST: There is API for [tokenization](https://docs.ai21.com/reference/tokenize-ref) - -## Stable diffusion - -There is no meaningful size limit on the prompt size. diff --git a/Makefile b/Makefile index 3d3af8d8..6d8ae63b 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,7 @@ PORT ?= 5001 IMAGE_NAME ?= ai-dial-adapter-bedrock PLATFORM ?= linux/amd64 DEV_PYTHON ?= 3.11 +DOCKER ?= docker ARGS= .PHONY: all install build serve clean lint format test integration_tests docker_build docker_run @@ -35,12 +36,12 @@ integration_tests: install poetry run nox -s integration_tests docker_test: - docker build --platform $(PLATFORM) -f Dockerfile.test -t $(IMAGE_NAME):test . - docker run --platform $(PLATFORM) --rm $(IMAGE_NAME):test + $(DOCKER) build --platform $(PLATFORM) -f Dockerfile.test -t $(IMAGE_NAME):test . + $(DOCKER) run --platform $(PLATFORM) --rm $(IMAGE_NAME):test docker_serve: - docker build --platform $(PLATFORM) -t $(IMAGE_NAME):dev . - docker run --platform $(PLATFORM) --env-file ./.env --rm -p $(PORT):5000 $(IMAGE_NAME):dev + $(DOCKER) build --platform $(PLATFORM) -t $(IMAGE_NAME):dev . 
+ $(DOCKER) run --platform $(PLATFORM) --env-file ./.env --rm -p $(PORT):5000 $(IMAGE_NAME):dev help: @echo '====================' diff --git a/README.md b/README.md index adea50ff..6556680e 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,18 @@ The project implements [AI DIAL API](https://epam-rail.com/dial_api) for languag Supported models: * Amazon Titan + - amazon.titan-tg1-large * AI21 J2 -* Anthropic Claude V1, V2 + - ai21.j2-grande-instruct + - ai21.j2-jumbo-instruct + - ai21.j2-mid + - ai21.j2-ultra +* Anthropic Claude + - anthropic.claude-instant-v1 + - anthropic.claude-v1 + - anthropic.claude-v2 * Stable Diffusion + - stability.stable-diffusion-xl ## Developer environment @@ -54,6 +63,7 @@ Copy `.env.example` to `.env` and customize it for your environment: |AWS_SECRET_ACCESS_KEY|NA|AWS credentials with access to Bedrock service| |DEFAULT_REGION||AWS region e.g. "us-east-1"| |LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod| +|AIDIAL_LOG_LEVEL|WARNING|AI DIAL SDK log level| |WEB_CONCURRENCY|1|Number of workers for the server| |TEST_SERVER_URL|http://0.0.0.0:5001|Server URL used in the integration tests| diff --git a/aidial_adapter_bedrock/app.py b/aidial_adapter_bedrock/app.py index e667f84b..aea0b7d0 100644 --- a/aidial_adapter_bedrock/app.py +++ b/aidial_adapter_bedrock/app.py @@ -2,11 +2,13 @@ import fastapi from aidial_sdk import DIALApp +from aidial_sdk import HTTPException as DialException +from fastapi import Request +from fastapi.responses import JSONResponse from aidial_adapter_bedrock.chat_completion import BedrockChatCompletion -from aidial_adapter_bedrock.llm.bedrock_adapter import BedrockModels from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment -from aidial_adapter_bedrock.llm.chat_emulation.types import ChatEmulationType +from aidial_adapter_bedrock.llm.model_listing import get_bedrock_models from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator from aidial_adapter_bedrock.universal_api.response import ( ModelObject, @@ -14,12 +16,11 @@ ) from aidial_adapter_bedrock.utils.env import get_env from aidial_adapter_bedrock.utils.log_config import LogConfig +from aidial_adapter_bedrock.utils.log_config import app_logger as log logging.config.dictConfig(LogConfig().dict()) default_region = get_env("DEFAULT_REGION") -default_chat_emulation_type = ChatEmulationType.META_CHAT - app = DIALApp(description="AWS Bedrock adapter for RAIL API") @@ -32,7 +33,7 @@ def healthcheck(): @app.get("/openai/models") @dial_exception_decorator async def models(): - bedrock_models = BedrockModels(region=default_region).models() + bedrock_models = get_bedrock_models(region=default_region) models = [ModelObject(id=model["modelId"]) for model in bedrock_models] return ModelsResponse(data=models) @@ -40,8 +41,21 @@ async def models(): for deployment in BedrockDeployment: app.add_chat_completion( deployment.get_model_id(), - BedrockChatCompletion( - region=default_region, - chat_emulation_type=default_chat_emulation_type, - ), + BedrockChatCompletion(region=default_region), + ) + + +@app.exception_handler(DialException) +async def exception_handler(request: Request, exc: DialException): + log.exception(f"Exception: {str(exc)}") + return JSONResponse( + status_code=exc.status_code, + content={ + "error": { + "message": exc.message, + "type": exc.type, + "code": exc.code, + "param": exc.param, + } + }, ) diff --git a/aidial_adapter_bedrock/chat_completion.py b/aidial_adapter_bedrock/chat_completion.py index 4ea14d20..182a05d5 100644 
--- a/aidial_adapter_bedrock/chat_completion.py +++ b/aidial_adapter_bedrock/chat_completion.py @@ -1,51 +1,58 @@ import asyncio -from typing import List +from typing import Optional, Set from aidial_sdk.chat_completion import ChatCompletion, Request, Response -from aidial_adapter_bedrock.llm.bedrock_adapter import BedrockAdapter -from aidial_adapter_bedrock.llm.chat_emulation.types import ChatEmulationType +from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer +from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator from aidial_adapter_bedrock.universal_api.request import ModelParameters from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage +from aidial_adapter_bedrock.utils.log_config import app_logger as log class BedrockChatCompletion(ChatCompletion): region: str - chat_emulation_type: ChatEmulationType - def __init__(self, region: str, chat_emulation_type: ChatEmulationType): + def __init__(self, region: str): self.region = region - self.chat_emulation_type = chat_emulation_type @dial_exception_decorator async def chat_completion(self, request: Request, response: Response): - model = await BedrockAdapter.create( + model_params = ModelParameters.create(request) + model = await get_bedrock_adapter( region=self.region, model_id=request.deployment_id, - model_params=ModelParameters.create(request), ) - async def generate_response(idx: int) -> TokenUsage: - model_response = await model.achat( - self.chat_emulation_type, request.messages - ) - + async def generate_response( + usage: TokenUsage, + discarded_messages_set: Set[Optional[int]], + choice_idx: int, + ) -> None: with response.create_choice() as choice: - choice.append_content(model_response.content) - - for data in model_response.data: - choice.add_attachment( - title=data.name, - data=data.content, - type=data.mime_type, - ) - - return model_response.usage - - usages: List[TokenUsage] = await asyncio.gather( - *(generate_response(idx) for idx in range(request.n or 1)) + consumer = ChoiceConsumer(choice) + await model.achat(consumer, model_params, request.messages) + usage.accumulate(consumer.usage) + discarded_messages_set.add(consumer.discarded_messages) + + usage = TokenUsage() + discarded_messages_set: Set[Optional[int]] = set() + + await asyncio.gather( + *( + generate_response(usage, discarded_messages_set, idx) + for idx in range(request.n or 1) + ) ) - usage = sum(usages, TokenUsage()) + log.debug(f"usage: {usage}") response.set_usage(usage.prompt_tokens, usage.completion_tokens) + + assert ( + len(discarded_messages_set) == 1 + ), "Discarded messages count must be the same for each choice." 
+ + discarded_messages = next(iter(discarded_messages_set)) + if discarded_messages is not None: + response.set_discarded_messages(discarded_messages) diff --git a/aidial_adapter_bedrock/llm/bedrock_adapter.py b/aidial_adapter_bedrock/llm/bedrock_adapter.py deleted file mode 100644 index 0f757d47..00000000 --- a/aidial_adapter_bedrock/llm/bedrock_adapter.py +++ /dev/null @@ -1,363 +0,0 @@ -import json -from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, Dict, List, Literal, Optional, TypedDict, Union - -import boto3 -from anthropic.tokenizer import count_tokens -from pydantic import BaseModel, Field -from typing_extensions import Annotated - -from aidial_adapter_bedrock.llm.chat_model import ( - ChatModel, - Model, - ModelResponse, - ResponseData, - TokenUsage, -) -from aidial_adapter_bedrock.universal_api.request import ModelParameters -from aidial_adapter_bedrock.utils.concurrency import make_async -from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log - - -class BedrockModelId(TypedDict): - modelArn: str - modelId: str - - -class IOutput(ABC): - @abstractmethod - def content(self) -> str: - pass - - def data(self) -> list[ResponseData]: - return [] - - @abstractmethod - def usage(self, prompt: str) -> TokenUsage: - pass - - -class AmazonResult(BaseModel): - tokenCount: int - outputText: str - completionReason: Optional[str] - - -class AmazonResponse(BaseModel, IOutput): - inputTextTokenCount: int - results: List[AmazonResult] - - def content(self) -> str: - assert ( - len(self.results) == 1 - ), "AmazonResponse should only have one result" - return self.results[0].outputText - - def usage(self, prompt: str) -> TokenUsage: - assert ( - len(self.results) == 1 - ), "AmazonResponse should only have one result" - return TokenUsage( - prompt_tokens=self.inputTextTokenCount, - completion_tokens=self.results[0].tokenCount, - ) - - -class AnthropicResponse(BaseModel, IOutput): - completion: str - stop_reason: str # Literal["stop_sequence"] - - def content(self) -> str: - return self.completion - - def usage(self, prompt: str) -> TokenUsage: - return TokenUsage( - prompt_tokens=count_tokens(prompt), - completion_tokens=count_tokens(self.completion), - ) - - -class TextRange(BaseModel): - start: int - end: int - - -class GeneratedToken(BaseModel): - token: str - logprob: float - raw_logprob: float - - -class Token(BaseModel): - generatedToken: GeneratedToken - topTokens: Optional[Any] - textRange: TextRange - - -class TextAndTokens(BaseModel): - text: str - tokens: List[Token] - - -class FinishReason(BaseModel): - reason: str # Literal["length", "endoftext"] - length: Optional[int] - - -class Completion(BaseModel): - data: TextAndTokens - finishReason: FinishReason - - -class AI21Response(BaseModel, IOutput): - id: int - prompt: TextAndTokens - completions: List[Completion] - - def content(self) -> str: - assert ( - len(self.completions) == 1 - ), "AI21Response should only have one completion" - return self.completions[0].data.text - - def usage(self, prompt: str) -> TokenUsage: - assert ( - len(self.completions) == 1 - ), "AI21Response should only have one completion" - return TokenUsage( - prompt_tokens=len(self.prompt.tokens), - completion_tokens=len(self.completions[0].data.tokens), - ) - - -class StabilityStatus(str, Enum): - SUCCESS = "success" - ERROR = "error" - - -class StabilityError(BaseModel): - id: str - message: str - name: str - - -class StabilityArtifact(BaseModel): - seed: int - base64: str - finish_reason: str = 
Field(alias="finishReason") - - -class StabilityResponse(BaseModel, IOutput): - # TODO: Use tagged union artifacts/error - result: str - artifacts: Optional[list[StabilityArtifact]] - error: Optional[StabilityError] - - def content(self) -> str: - self._throw_if_error() - return "" - - def data(self) -> list[ResponseData]: - self._throw_if_error() - return [ResponseData(mime_type="image/png", name="image", content=self.artifacts[0].base64)] # type: ignore - - def usage(self, prompt: str) -> TokenUsage: - return TokenUsage( - prompt_tokens=0, - completion_tokens=0, - ) - - def _throw_if_error(self): - if self.result == StabilityStatus.ERROR: - raise Exception(self.error.message) # type: ignore - - -class TaggedAmazonResponse(AmazonResponse): - provider: Literal["amazon"] - - -class TaggedAnthropicResponse(AnthropicResponse): - provider: Literal["anthropic"] - - -class TaggedAI21Response(AI21Response): - provider: Literal["ai21"] - - -class TaggedStabilityResponse(StabilityResponse): - provider: Literal["stability"] - - -class BedrockResponse(BaseModel, IOutput): - __root__: Annotated[ - Union[ - TaggedAmazonResponse, - TaggedAnthropicResponse, - TaggedAI21Response, - TaggedStabilityResponse, - ], - Field(discriminator="provider"), - ] - - def content(self) -> str: - return self.__root__.content() - - def data(self) -> list[ResponseData]: - return self.__root__.data() - - def usage(self, prompt: str) -> TokenUsage: - return self.__root__.usage(prompt) - - -class BedrockModels: - def __init__(self, region: str): - session = boto3.Session() - self.bedrock = session.client("bedrock", region) - - def models(self) -> List[BedrockModelId]: - return self.bedrock.list_foundation_models()["modelSummaries"] - - -# Simplified copy of langchain.llms.bedrock.LLMInputOutputAdapter.prepare_input -def prepare_input( - provider: str, prompt: str, model_kwargs: Dict[str, Any] -) -> Dict[str, Any]: - input_body = {**model_kwargs} - if provider == "anthropic" or provider == "ai21": - input_body["prompt"] = prompt - elif provider == "amazon": - input_body = dict() - input_body["inputText"] = prompt - input_body["textGenerationConfig"] = {**model_kwargs} - elif provider == "stability": - input_body = dict() - input_body["text_prompts"] = [{"text": prompt}] - else: - input_body["inputText"] = prompt - - return input_body - - -def prepare_model_kwargs( - provider: str, model_params: ModelParameters -) -> Dict[str, Any]: - model_kwargs = {} - - # NOTE: See https://docs.anthropic.com/claude/reference/complete_post - if provider == "anthropic": - if model_params.max_tokens is not None: - model_kwargs["max_tokens_to_sample"] = model_params.max_tokens - else: - # The max tokens parameter is required for Anthropic models. - # Choosing reasonable default. - model_kwargs["max_tokens_to_sample"] = 500 - - if model_params.stop is not None: - model_kwargs["stop_sequences"] = model_params.stop - - if model_params.temperature is not None: - model_kwargs["temperature"] = model_params.temperature - - # Doesn't have any effect. AWS always sends the whole response at once. - # streaming = model_kwargs.streaming - - if model_params.top_p is not None: - model_kwargs["top_p"] = model_params.top_p - - # OpenAI API doesn't have top_k parameter. 
- # if model_params.top_k is not None: - # model_kwargs["top_k"] = model_params.top_k - - # NOTE: API See https://docs.ai21.com/reference/j2-instruct-ref - # NOTE: Per-model token limits: https://docs.ai21.com/docs/choosing-the-right-instance-type-for-amazon-sagemaker-models#foundation-models - if provider == "ai21": - if model_params.max_tokens is not None: - model_kwargs["maxTokens"] = model_params.max_tokens - else: - # The default for max tokens is 16, which is too small for most use cases - model_kwargs["maxTokens"] = 500 - - if model_params.temperature is not None: - model_kwargs["temperature"] = model_params.temperature - else: - # The default AI21 temperature is 0.7. - # The default OpenAI temperature is 1.0. - # Choosing the OpenAI default since we pretend AI21 to be OpenAI. - model_kwargs["temperature"] = 1.0 - - if model_params.top_p is not None: - model_kwargs["topP"] = model_params.top_p - - if model_params.stop is not None: - model_kwargs["stopSequences"] = model_params.stop - - # NOTE: AI21 has "numResults" parameter, however we emulate multiple result - # via mutliple calls to support all models uniformly. - - if provider == "amazon": - if model_params.temperature is not None: - model_kwargs["temperature"] = model_params.temperature - # NOTE: There is no documentation for Amazon models currently. - # NOTE: max tokens is 128 by default. The parameter name is not known. - - return model_kwargs - - -class BedrockAdapter(ChatModel): - def __init__( - self, - model_id: str, - model_provider: str, - model_params: ModelParameters, - model_kwargs: Dict[str, Any], - bedrock: Any, - ): - self.model_id = model_id - self.model_provider = model_provider - self.model_params = model_params - self.model_kwargs = model_kwargs - self.bedrock = bedrock - - @classmethod - async def create( - cls, model_id: str, region: str, model_params: ModelParameters - ) -> "BedrockAdapter": - model_provider = Model.parse(model_id).provider - - model_kwargs = prepare_model_kwargs(model_provider, model_params) - - bedrock = await make_async( - lambda _: boto3.Session().client("bedrock-runtime", region), - (), - ) - - return cls( - model_id, model_provider, model_params, model_kwargs, bedrock - ) - - async def acall(self, prompt: str) -> ModelResponse: - return await make_async(self._call, prompt) - - def _call(self, prompt: str) -> ModelResponse: - log.debug(f"prompt:\n{prompt}") - - model_response = self.bedrock.invoke_model( - body=json.dumps( - prepare_input(self.model_provider, prompt, self.model_kwargs) - ), - modelId=self.model_id, - contentType="application/json", - accept="application/json", - ) - - body = json.loads(model_response["body"].read()) - resp = BedrockResponse.parse_obj( - {"provider": self.model_provider, **body} - ) - response = ModelResponse( - content=resp.content(), data=resp.data(), usage=resp.usage(prompt) - ) - - log.debug(f"response:\n{response.json()}") - return response diff --git a/aidial_adapter_bedrock/llm/chat_emulation/claude.py b/aidial_adapter_bedrock/llm/chat_emulation/claude.py deleted file mode 100644 index 5bb5c8a5..00000000 --- a/aidial_adapter_bedrock/llm/chat_emulation/claude.py +++ /dev/null @@ -1,55 +0,0 @@ -from enum import Enum -from typing import List - -from pydantic import BaseModel - -from aidial_adapter_bedrock.llm.exceptions import ValidationError -from aidial_adapter_bedrock.llm.message import BaseMessage - - -class ClaudeRole(Enum): - HUMAN = "Human" - ASSISTANT = "Assistant" - - -class ClaudeMessage(BaseModel): - role: ClaudeRole - content: str - - 
-class ClaudeHistory: - history: List[ClaudeMessage] - - def __init__(self): - self.history = [] - - def add(self, msg: ClaudeMessage): - if len(self.history) > 0 and self.history[-1].role == msg.role: - self.history[-1].content += " " + msg.content - else: - self.history.append(msg) - - def print(self) -> str: - return "".join( - [ - f"\n\n{msg.role.value}: {msg.content.lstrip()}".rstrip() - for msg in self.history - ] - ) - - -def emulate(prompt: List[BaseMessage]) -> str: - if len(prompt) == 0: - raise ValidationError("List of messages must not be empty") - - history = ClaudeHistory() - for msg in prompt: - role = ( - ClaudeRole.HUMAN - if msg.type in ["system", "human"] - else ClaudeRole.ASSISTANT - ) - history.add(ClaudeMessage(role=role, content=msg.content)) - history.add(ClaudeMessage(role=ClaudeRole.ASSISTANT, content="")) - - return history.print() diff --git a/aidial_adapter_bedrock/llm/chat_emulation/claude_chat.py b/aidial_adapter_bedrock/llm/chat_emulation/claude_chat.py new file mode 100644 index 00000000..be959c81 --- /dev/null +++ b/aidial_adapter_bedrock/llm/chat_emulation/claude_chat.py @@ -0,0 +1,90 @@ +from enum import Enum +from typing import Callable, List, Set, Tuple + +import anthropic + +from aidial_adapter_bedrock.llm.chat_emulation.history import ( + FormattedMessage, + History, + is_important_message, +) +from aidial_adapter_bedrock.llm.exceptions import ValidationError +from aidial_adapter_bedrock.llm.message import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, +) +from aidial_adapter_bedrock.utils.list import exclude_indices + + +class RolePrompt(str, Enum): + HUMAN = anthropic.HUMAN_PROMPT + AI = anthropic.AI_PROMPT + + +STOP_SEQUENCES: List[str] = [anthropic.HUMAN_PROMPT] + + +def _format_message(message: BaseMessage) -> str: + role = ( + RolePrompt.HUMAN + if isinstance(message, (SystemMessage, HumanMessage)) + else RolePrompt.AI + ) + return (role + " " + message.content.lstrip()).rstrip() + + +class ClaudeChatHistory(History): + def trim( + self, + count_tokens: Callable[[str], int], + max_prompt_tokens: int, + ) -> Tuple["ClaudeChatHistory", int]: + message_tokens = [ + count_tokens(message.text) for message in self.messages + ] + prompt_tokens = sum(message_tokens) + if prompt_tokens <= max_prompt_tokens: + return self, 0 + + discarded_messages: Set[int] = set() + for index, message in enumerate(self.messages): + if message.is_important: + continue + + discarded_messages.add(index) + prompt_tokens -= message_tokens[index] + if prompt_tokens <= max_prompt_tokens: + return ClaudeChatHistory( + messages=exclude_indices(self.messages, discarded_messages) + ), len(discarded_messages) + + if discarded_messages: + raise ValidationError( + f"The token size of system messages and the last user message ({prompt_tokens}) exceeds" + f" prompt token limit ({max_prompt_tokens})." + ) + + raise ValidationError( + f"Prompt token size ({prompt_tokens}) exceeds prompt token limit ({max_prompt_tokens})." 
+ ) + + @classmethod + def create(cls, messages: List[BaseMessage]) -> "ClaudeChatHistory": + formatted_messages = [] + + for index, message in enumerate(messages): + formatted_messages.append( + FormattedMessage( + text=_format_message(message), + source_message=message, + is_important=is_important_message(messages, index), + ) + ) + + formatted_messages.append( + FormattedMessage(text=_format_message(AIMessage(content=""))) + ) + + return cls(messages=formatted_messages) diff --git a/aidial_adapter_bedrock/llm/chat_emulation/history.py b/aidial_adapter_bedrock/llm/chat_emulation/history.py new file mode 100644 index 00000000..e8f01663 --- /dev/null +++ b/aidial_adapter_bedrock/llm/chat_emulation/history.py @@ -0,0 +1,31 @@ +from abc import ABC, abstractmethod +from typing import Callable, List, Optional, Tuple + +from pydantic.main import BaseModel + +from aidial_adapter_bedrock.llm.message import BaseMessage, SystemMessage + + +def is_important_message(messages: List[BaseMessage], index: int) -> bool: + return ( + isinstance(messages[index], SystemMessage) or index == len(messages) - 1 + ) + + +class FormattedMessage(BaseModel): + text: str + source_message: Optional[BaseMessage] = None + is_important: bool = True + + +class History(ABC, BaseModel): + messages: List[FormattedMessage] + + def format(self) -> str: + return "".join(message.text for message in self.messages) + + @abstractmethod + def trim( + self, count_tokens: Callable[[str], int], max_prompt_tokens: int + ) -> Tuple["History", int]: + pass diff --git a/aidial_adapter_bedrock/llm/chat_emulation/meta_chat.py b/aidial_adapter_bedrock/llm/chat_emulation/meta_chat.py deleted file mode 100644 index 20b53f9d..00000000 --- a/aidial_adapter_bedrock/llm/chat_emulation/meta_chat.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import List, Tuple - -from aidial_adapter_bedrock.llm.exceptions import ValidationError -from aidial_adapter_bedrock.llm.message import ( - AIMessage, - BaseMessage, - SystemMessage, -) -from aidial_adapter_bedrock.utils.operators import Unary -from aidial_adapter_bedrock.utils.text import enforce_stop_tokens, remove_prefix - -HUMAN = "Human" -ASSISTANT = "Assistant" -SYSTEM = "System" - -prelude = f""" -You are a helpful assistant participating in a dialog with a user. -The messages from the user start with "{HUMAN}:". -The messages from you start with "{ASSISTANT}:". -Reply to the last message from the user taking into account the preceding dialog history. 
-==================== -""".strip() - - -def type_to_role(ty: str) -> str: - roles = {"human": HUMAN, "system": HUMAN, "ai": ASSISTANT} - return roles.get(ty, ty) - - -def emulate(prompt: List[BaseMessage]) -> Tuple[str, Unary[str]]: - if len(prompt) == 0: - raise ValidationError("List of messages must not be empty") - - history = prompt.copy() - history.append(AIMessage(content="")) - - msgs = [prelude] - for msg in history: - # Skipping empty system messages - if isinstance(msg, SystemMessage) and msg.content.strip() == "": - continue - - role = type_to_role(msg.type) - msgs.append(f"\n\n{role}: {msg.content.lstrip()}".rstrip()) - - return "".join(msgs), post_process - - -stop = f"{HUMAN}:" - - -def post_process(response: str) -> str: - response = enforce_stop_tokens(response, [stop]) - response = remove_prefix(response.strip(), f"{ASSISTANT}:") - return response.strip() diff --git a/aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py b/aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py new file mode 100644 index 00000000..b2ea5302 --- /dev/null +++ b/aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py @@ -0,0 +1,136 @@ +from enum import Enum +from typing import Callable, List, Optional, Set, Tuple + +from aidial_adapter_bedrock.llm.chat_emulation.history import ( + FormattedMessage, + History, + is_important_message, +) +from aidial_adapter_bedrock.llm.exceptions import ValidationError +from aidial_adapter_bedrock.llm.message import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, +) +from aidial_adapter_bedrock.utils.list import exclude_indices + + +class RolePrompt(str, Enum): + HUMAN = "\n\nHuman:" + ASSISTANT = "\n\nAssistant:" + + +STOP_SEQUENCES: List[str] = [RolePrompt.HUMAN] + + +PRELUDE = f""" +You are a helpful assistant participating in a dialog with a user. +The messages from the user start with "{RolePrompt.HUMAN.strip()}". +The messages from you start with "{RolePrompt.ASSISTANT.strip()}". +Reply to the last message from the user taking into account the preceding dialog history. 
+==================== +""".strip() + + +def _format_message(message: BaseMessage) -> str: + role = ( + RolePrompt.HUMAN + if isinstance(message, (SystemMessage, HumanMessage)) + else RolePrompt.ASSISTANT + ) + return (role + " " + message.content.lstrip()).rstrip() + + +class PseudoChatHistory(History): + stop_sequences: List[str] + + def trim( + self, count_tokens: Callable[[str], int], max_prompt_tokens: int + ) -> Tuple["PseudoChatHistory", int]: + message_tokens = [ + count_tokens(message.text) for message in self.messages + ] + prompt_tokens = sum(message_tokens) + if prompt_tokens <= max_prompt_tokens: + return self, 0 + + discarded_messages: Set[int] = set() + source_messages_count: int = 0 + last_source_message: Optional[BaseMessage] = None + for index, message in enumerate(self.messages): + if message.source_message: + source_messages_count += 1 + last_source_message = message.source_message + + if message.is_important: + continue + + discarded_messages.add(index) + prompt_tokens -= message_tokens[index] + if prompt_tokens <= max_prompt_tokens: + return ( + PseudoChatHistory.create( + messages=[ + message.source_message + for message in exclude_indices( + self.messages, discarded_messages + ) + if message.source_message + ] + ), + len(discarded_messages), + ) + + if discarded_messages: + discarded_messages_count = len(discarded_messages) + if ( + source_messages_count - discarded_messages_count == 1 + and isinstance(last_source_message, HumanMessage) + ): + history = PseudoChatHistory.create([last_source_message]) + prompt_tokens = sum( + count_tokens(message.text) for message in history.messages + ) + if prompt_tokens <= max_prompt_tokens: + return history, len(discarded_messages) + + raise ValidationError( + f"The token size of system messages and the last user message ({prompt_tokens}) exceeds" + f" prompt token limit ({max_prompt_tokens})." + ) + + raise ValidationError( + f"Prompt token size ({prompt_tokens}) exceeds prompt token limit ({max_prompt_tokens})." 
+ ) + + @classmethod + def create(cls, messages: List[BaseMessage]) -> "PseudoChatHistory": + if len(messages) == 1 and isinstance(messages[0], HumanMessage): + single_message = messages[0] + return cls( + messages=[ + FormattedMessage( + text=single_message.content, + source_message=single_message, + ) + ], + stop_sequences=[], + ) + + formatted_messages = [FormattedMessage(text=PRELUDE)] + + for index, message in enumerate(messages): + formatted_messages.append( + FormattedMessage( + text=_format_message(message), + source_message=message, + is_important=is_important_message(messages, index), + ) + ) + + formatted_messages.append( + FormattedMessage(text=_format_message(AIMessage(content=""))) + ) + + return cls(messages=formatted_messages, stop_sequences=STOP_SEQUENCES) diff --git a/aidial_adapter_bedrock/llm/chat_emulation/types.py b/aidial_adapter_bedrock/llm/chat_emulation/types.py deleted file mode 100644 index ab7a6ce7..00000000 --- a/aidial_adapter_bedrock/llm/chat_emulation/types.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class ChatEmulationType(Enum): - ZERO_MEMORY = "zero_memory" - META_CHAT = "meta_chat" diff --git a/aidial_adapter_bedrock/llm/chat_emulation/zero_memory.py b/aidial_adapter_bedrock/llm/chat_emulation/zero_memory.py deleted file mode 100644 index 18892bad..00000000 --- a/aidial_adapter_bedrock/llm/chat_emulation/zero_memory.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import List - -from aidial_adapter_bedrock.llm.exceptions import ValidationError -from aidial_adapter_bedrock.llm.message import BaseMessage - - -def emulate(prompt: List[BaseMessage]) -> str: - if len(prompt) == 0: - raise ValidationError("List of messages must not be empty") - return prompt[-1].content diff --git a/aidial_adapter_bedrock/llm/chat_emulation/zero_memory_chat.py b/aidial_adapter_bedrock/llm/chat_emulation/zero_memory_chat.py new file mode 100644 index 00000000..525b4ead --- /dev/null +++ b/aidial_adapter_bedrock/llm/chat_emulation/zero_memory_chat.py @@ -0,0 +1,33 @@ +from typing import Callable, List, Tuple + +from aidial_adapter_bedrock.llm.chat_emulation.history import ( + FormattedMessage, + History, +) +from aidial_adapter_bedrock.llm.exceptions import ValidationError +from aidial_adapter_bedrock.llm.message import BaseMessage + + +class ZeroMemoryChatHistory(History): + discarded_messages: int + + def trim( + self, count_tokens: Callable[[str], int], max_prompt_tokens: int + ) -> Tuple["ZeroMemoryChatHistory", int]: + # Possibly, not supported operation + return self, self.discarded_messages + + @classmethod + def create(cls, messages: List[BaseMessage]) -> "ZeroMemoryChatHistory": + if len(messages) == 0: + raise ValidationError("List of messages must not be empty") + + last_message = messages[-1] + return cls( + messages=[ + FormattedMessage( + text=last_message.content, source_message=last_message + ) + ], + discarded_messages=len(messages) - 1, + ) diff --git a/aidial_adapter_bedrock/llm/chat_model.py b/aidial_adapter_bedrock/llm/chat_model.py index 80420c64..b5fdb35c 100644 --- a/aidial_adapter_bedrock/llm/chat_model.py +++ b/aidial_adapter_bedrock/llm/chat_model.py @@ -1,60 +1,114 @@ from abc import ABC, abstractmethod -from typing import List, Tuple +from typing import Callable, List, Optional from aidial_sdk.chat_completion import Message from pydantic import BaseModel -import aidial_adapter_bedrock.llm.chat_emulation.claude as claude -import aidial_adapter_bedrock.llm.chat_emulation.meta_chat as meta_chat -import 
aidial_adapter_bedrock.llm.chat_emulation.zero_memory as zero_memory -from aidial_adapter_bedrock.llm.chat_emulation.types import ChatEmulationType -from aidial_adapter_bedrock.llm.message import BaseMessage, parse_message +from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import ( + PseudoChatHistory, +) +from aidial_adapter_bedrock.llm.consumer import Consumer +from aidial_adapter_bedrock.llm.exceptions import ValidationError +from aidial_adapter_bedrock.llm.message import ( + BaseMessage, + SystemMessage, + parse_message, +) from aidial_adapter_bedrock.universal_api.request import ModelParameters -from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage -from aidial_adapter_bedrock.utils.operators import Unary, identity -from aidial_adapter_bedrock.utils.text import enforce_stop_tokens +from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log -class ResponseData(BaseModel): - mime_type: str - name: str - content: str +def _is_empty_system_message(msg: BaseMessage) -> bool: + return isinstance(msg, SystemMessage) and msg.content.strip() == "" -class ModelResponse(BaseModel): - content: str - data: List[ResponseData] - usage: TokenUsage +class ChatPrompt(BaseModel): + text: str + stop_sequences: List[str] + discarded_messages: Optional[int] = None class ChatModel(ABC): model_id: str - model_params: ModelParameters + + def __init__(self, model_id: str): + self.model_id = model_id + + @abstractmethod + def _prepare_prompt( + self, messages: List[BaseMessage], max_prompt_tokens: Optional[int] + ) -> ChatPrompt: + pass @abstractmethod - async def acall(self, prompt: str) -> ModelResponse: - # TODO: Support multiple results: call the model in cycle of `self.model_params.n` iterations + async def _apredict( + self, consumer: Consumer, model_params: ModelParameters, prompt: str + ) -> None: pass + def _validate_and_cleanup_messages( + self, messages: List[BaseMessage] + ) -> List[BaseMessage]: + # Skipping empty system messages + messages = [ + msg for msg in messages if not _is_empty_system_message(msg) + ] + + if len(messages) == 0: + raise ValidationError("List of messages must not be empty") + + return messages + async def achat( self, - chat_emulation_type: ChatEmulationType, + consumer: Consumer, + model_params: ModelParameters, messages: List[Message], - ) -> ModelResponse: - prompt, post_process = emulate_chat( - self.model_id, - chat_emulation_type, - list(map(parse_message, messages)), + ): + base_messages = list(map(parse_message, messages)) + base_messages = self._validate_and_cleanup_messages(base_messages) + + chat_prompt = self._prepare_prompt( + base_messages, model_params.max_prompt_tokens ) - response = await self.acall(prompt) + model_params = model_params.add_stop_sequences( + chat_prompt.stop_sequences + ) - content = post_process( - enforce_stop_tokens(response.content, self.model_params.stop) + log.debug( + f"model parameters:\n{model_params.json(indent=2, exclude_none=True)}" ) + log.debug(f"prompt:\n{chat_prompt.text}") + + await self._apredict(consumer, model_params, chat_prompt.text) + + if chat_prompt.discarded_messages is not None: + consumer.set_discarded_messages(chat_prompt.discarded_messages) - return ModelResponse( - content=content, data=response.data, usage=response.usage + +class PseudoChatModel(ChatModel, ABC): + def __init__(self, model_id: str, count_tokens: Callable[[str], int]): + super().__init__(model_id) + self.count_tokens = count_tokens + + def _prepare_prompt( + self, messages: List[BaseMessage], 
max_prompt_tokens: Optional[int] + ) -> ChatPrompt: + history = PseudoChatHistory.create(messages) + if max_prompt_tokens is None: + return ChatPrompt( + text=history.format(), stop_sequences=history.stop_sequences + ) + + history, discarded_messages_count = history.trim( + lambda text: self.count_tokens(text), max_prompt_tokens + ) + + return ChatPrompt( + text=history.format(), + stop_sequences=history.stop_sequences, + discarded_messages=discarded_messages_count, ) @@ -67,26 +121,8 @@ def parse(cls, model_id: str) -> "Model": parts = model_id.split(".") if len(parts) != 2: raise Exception( - f"Invalid model id '{model_id}'. The model id is expected to be in format 'provider.model'" + f"Invalid model id '{model_id}'. " + "The model id is expected to be in format 'provider.model'" ) provider, model = parts return cls(provider=provider, model=model) - - -def emulate_chat( - model_id: str, emulation_type: ChatEmulationType, history: List[BaseMessage] -) -> Tuple[str, Unary[str]]: - model = Model.parse(model_id) - if model.provider == "anthropic" and "claude" in model.model: - return claude.emulate(history), identity - - if model.provider == "stability": - return zero_memory.emulate(history), identity - - match emulation_type: - case ChatEmulationType.ZERO_MEMORY: - return zero_memory.emulate(history), identity - case ChatEmulationType.META_CHAT: - return meta_chat.emulate(history) - case _: - raise Exception(f"Invalid emulation type: {emulation_type}") diff --git a/aidial_adapter_bedrock/llm/consumer.py b/aidial_adapter_bedrock/llm/consumer.py new file mode 100644 index 00000000..2d5f0f90 --- /dev/null +++ b/aidial_adapter_bedrock/llm/consumer.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from aidial_sdk.chat_completion import Choice +from pydantic import BaseModel + +from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage + + +class Attachment(BaseModel): + type: str | None = None + title: str | None = None + data: str | None = None + url: str | None = None + reference_url: str | None = None + reference_type: str | None = None + + +class Consumer(ABC): + @abstractmethod + def append_content(self, content: str): + pass + + @abstractmethod + def add_attachment(self, attachment: Attachment): + pass + + @abstractmethod + def add_usage(self, usage: TokenUsage): + pass + + @abstractmethod + def set_discarded_messages(self, discarded_messages: int): + pass + + +class ChoiceConsumer(Consumer): + usage: TokenUsage + choice: Choice + discarded_messages: Optional[int] + + def __init__(self, choice: Choice): + self.choice = choice + self.usage = TokenUsage() + self.discarded_messages = None + + def append_content(self, content: str): + self.choice.append_content(content) + + def add_attachment(self, attachment: Attachment): + self.choice.add_attachment(**attachment.dict()) + + def add_usage(self, usage: TokenUsage): + self.usage += usage + + def set_discarded_messages(self, discarded_messages: int): + self.discarded_messages = discarded_messages + + +class CollectConsumer(Consumer): + usage: TokenUsage + content: str + attachments: List[Attachment] + discarded_messages: Optional[int] + + def __init__(self): + self.usage = TokenUsage() + self.content = "" + self.attachments = [] + self.discarded_messages = None + + def append_content(self, content: str): + self.content += content + + def add_attachment(self, attachment: Attachment): + self.attachments.append(attachment) + + def add_usage(self, usage: TokenUsage): + self.usage += usage + + 
def set_discarded_messages(self, discarded_messages: int): + self.discarded_messages = discarded_messages diff --git a/aidial_adapter_bedrock/llm/model/__init__.py b/aidial_adapter_bedrock/llm/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aidial_adapter_bedrock/llm/model/adapter.py b/aidial_adapter_bedrock/llm/model/adapter.py new file mode 100644 index 00000000..8c12e136 --- /dev/null +++ b/aidial_adapter_bedrock/llm/model/adapter.py @@ -0,0 +1,40 @@ +import boto3 + +from aidial_adapter_bedrock.llm.chat_model import ChatModel, Model +from aidial_adapter_bedrock.llm.model.ai21 import AI21Adapter +from aidial_adapter_bedrock.llm.model.amazon import AmazonAdapter +from aidial_adapter_bedrock.llm.model.anthropic import AnthropicAdapter +from aidial_adapter_bedrock.llm.model.stability import StabilityAdapter +from aidial_adapter_bedrock.utils.concurrency import make_async + + +def count_tokens(string: str) -> int: + """ + The number of bytes is a proxy for the number of tokens for + models which do not provide any means to count tokens. + + Any token number estimator should satisfy the following requirements: + 1. Overestimation of number of tokens is allowed. + It's ok to trim the chat history more than necessary. + 2. Underestimation of number of tokens is prohibited. + It's wrong to leave the chat history as is when the trimming was actually required. + """ + return len(string.encode("utf-8")) + + +async def get_bedrock_adapter(model_id: str, region: str) -> ChatModel: + bedrock = await make_async( + lambda _: boto3.Session().client("bedrock-runtime", region), () + ) + model_provider = Model.parse(model_id).provider + match model_provider: + case "anthropic": + return AnthropicAdapter(bedrock, model_id) + case "ai21": + return AI21Adapter(bedrock, model_id, count_tokens) + case "stability": + return StabilityAdapter(bedrock, model_id) + case "amazon": + return AmazonAdapter(bedrock, model_id, count_tokens) + case _: + raise ValueError(f"Unknown model provider: '{model_provider}'") diff --git a/aidial_adapter_bedrock/llm/model/ai21.py b/aidial_adapter_bedrock/llm/model/ai21.py new file mode 100644 index 00000000..ed114fa9 --- /dev/null +++ b/aidial_adapter_bedrock/llm/model/ai21.py @@ -0,0 +1,136 @@ +import json +from typing import Any, Callable, Dict, List, Optional + +from pydantic import BaseModel + +from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel +from aidial_adapter_bedrock.llm.consumer import Consumer +from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AI21 +from aidial_adapter_bedrock.universal_api.request import ModelParameters +from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage +from aidial_adapter_bedrock.utils.concurrency import make_async + + +class TextRange(BaseModel): + start: int + end: int + + +class GeneratedToken(BaseModel): + token: str + logprob: float + raw_logprob: float + + +class Token(BaseModel): + generatedToken: GeneratedToken + topTokens: Optional[Any] + textRange: TextRange + + +class TextAndTokens(BaseModel): + text: str + tokens: List[Token] + + +class FinishReason(BaseModel): + reason: str # Literal["length", "endoftext"] + length: Optional[int] + + +class Completion(BaseModel): + data: TextAndTokens + finishReason: FinishReason + + +class AI21Response(BaseModel): + id: int + prompt: TextAndTokens + completions: List[Completion] + + def content(self) -> str: + assert ( + len(self.completions) == 1 + ), "AI21Response should only have one completion" + return 
self.completions[0].data.text + + def usage(self) -> TokenUsage: + assert ( + len(self.completions) == 1 + ), "AI21Response should only have one completion" + return TokenUsage( + prompt_tokens=len(self.prompt.tokens), + completion_tokens=len(self.completions[0].data.tokens), + ) + + +# NOTE: See https://docs.ai21.com/reference/j2-instruct-ref +def prepare_model_kwargs(model_params: ModelParameters) -> Dict[str, Any]: + model_kwargs = {} + + if model_params.max_tokens is not None: + model_kwargs["maxTokens"] = model_params.max_tokens + else: + # The default for max tokens is 16, which is too small for most use cases. + # Choosing reasonable default. + model_kwargs["maxTokens"] = DEFAULT_MAX_TOKENS_AI21 + + if model_params.temperature is not None: + model_kwargs["temperature"] = model_params.temperature + else: + # The default AI21 temperature is 0.7. + # The default OpenAI temperature is 1.0. + # Choosing the OpenAI default since we pretend AI21 to be OpenAI. + model_kwargs["temperature"] = 1.0 + + if model_params.top_p is not None: + model_kwargs["topP"] = model_params.top_p + + if model_params.stop is not None: + model_kwargs["stopSequences"] = ( + [model_params.stop] + if isinstance(model_params.stop, str) + else model_params.stop + ) + + # NOTE: AI21 has "numResults" parameter, however we emulate multiple result + # via multiple calls to support all models uniformly. + + return model_kwargs + + +def prepare_input(prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: + return {"prompt": prompt, **model_kwargs} + + +class AI21Adapter(PseudoChatModel): + def __init__( + self, bedrock: Any, model_id: str, count_tokens: Callable[[str], int] + ): + super().__init__(model_id, count_tokens) + self.bedrock = bedrock + + async def _apredict( + self, consumer: Consumer, model_params: ModelParameters, prompt: str + ): + await make_async( + lambda args: self._call(*args), (consumer, model_params, prompt) + ) + + def _call( + self, consumer: Consumer, model_params: ModelParameters, prompt: str + ): + model_kwargs = prepare_model_kwargs(model_params) + + model_response = self.bedrock.invoke_model( + modelId=self.model_id, + accept="application/json", + contentType="application/json", + body=json.dumps(prepare_input(prompt, model_kwargs)), + ) + + body = json.loads(model_response["body"].read()) + resp = AI21Response.parse_obj(body) + + consumer.append_content(resp.content()) + consumer.add_usage(resp.usage()) diff --git a/aidial_adapter_bedrock/llm/model/amazon.py b/aidial_adapter_bedrock/llm/model/amazon.py new file mode 100644 index 00000000..c8ffc8b4 --- /dev/null +++ b/aidial_adapter_bedrock/llm/model/amazon.py @@ -0,0 +1,196 @@ +import json +from typing import Any, Callable, Dict, Generator, List, Optional + +from pydantic import BaseModel +from typing_extensions import override + +import aidial_adapter_bedrock.utils.stream as stream +from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import RolePrompt +from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel +from aidial_adapter_bedrock.llm.consumer import Consumer +from aidial_adapter_bedrock.llm.message import BaseMessage +from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AMAZON +from aidial_adapter_bedrock.universal_api.request import ModelParameters +from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage +from aidial_adapter_bedrock.utils.concurrency import make_async +from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log + + +class AmazonResult(BaseModel): + 
tokenCount: int + outputText: str + completionReason: Optional[str] + + +class AmazonResponse(BaseModel): + inputTextTokenCount: int + results: List[AmazonResult] + + def content(self) -> str: + assert ( + len(self.results) == 1 + ), "AmazonResponse should only have one result" + return self.results[0].outputText + + def usage(self) -> TokenUsage: + assert ( + len(self.results) == 1 + ), "AmazonResponse should only have one result" + return TokenUsage( + prompt_tokens=self.inputTextTokenCount, + completion_tokens=self.results[0].tokenCount, + ) + + +def prepare_model_kwargs(model_params: ModelParameters) -> Dict[str, Any]: + model_kwargs = {} + + if model_params.temperature is not None: + model_kwargs["temperature"] = model_params.temperature + + if model_params.top_p is not None: + model_kwargs["topP"] = model_params.top_p + + if model_params.max_tokens is not None: + model_kwargs["maxTokenCount"] = model_params.max_tokens + else: + # The default for max tokens is 128, which is too small for most use cases. + # Choosing reasonable default. + model_kwargs["maxTokenCount"] = DEFAULT_MAX_TOKENS_AMAZON + + # NOTE: Amazon Titan (amazon.titan-tg1-large) currently only supports + # stop sequences matching pattern "$\|+". + # if model_params.stop is not None: + # model_kwargs["stopSequences"] = model_params.stop + + return model_kwargs + + +def prepare_input(prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: + return { + "inputText": prompt, + "textGenerationConfig": model_kwargs, + } + + +def get_generator_for_streaming( + response: Any, + usage: TokenUsage, +) -> Generator[str, None, None]: + body = response["body"] + for event in body: + chunk = event.get("chunk") + if chunk: + chunk_obj = json.loads(chunk.get("bytes").decode()) + log.debug(f"chunk: {chunk_obj}") + + input_tokens = chunk_obj.get("inputTextTokenCount") + if input_tokens is not None: + usage.prompt_tokens = input_tokens + + output_tokens = chunk_obj.get("totalOutputTextTokenCount") + if output_tokens is not None: + usage.completion_tokens = output_tokens + + yield chunk_obj["outputText"] + + +def get_generator_for_non_streaming( + response: Any, + usage: TokenUsage, +) -> Generator[str, None, None]: + body = json.loads(response["body"].read()) + log.debug(f"body: {body}") + + resp = AmazonResponse.parse_obj(body) + + token_usage = resp.usage() + usage.completion_tokens = token_usage.completion_tokens + usage.prompt_tokens = token_usage.prompt_tokens + + yield resp.content() + + +def post_process_stream( + model_params: ModelParameters, content_stream: Generator[str, None, None] +) -> Generator[str, None, None]: + content_stream = stream.lstrip(content_stream) + + # Titan occasionally starts its response with the role prefix + content_stream = stream.remove_prefix( + content_stream, RolePrompt.ASSISTANT.lstrip() + " " + ) + + # Titan doesn't support stop sequences, so do it manually + if model_params.stop is not None: + stop_sequences = ( + [model_params.stop] + if isinstance(model_params.stop, str) + else model_params.stop + ) + content_stream = stream.stop_at(content_stream, stop_sequences) + + # After all the post processing, the stream may become empty. + # To avoid this, add a space to the stream. 
+ content_stream = stream.ensure_not_empty(content_stream, " ") + + return content_stream + + +class AmazonAdapter(PseudoChatModel): + def __init__( + self, bedrock: Any, model_id: str, count_tokens: Callable[[str], int] + ): + super().__init__(model_id, count_tokens) + self.bedrock = bedrock + + @override + def _validate_and_cleanup_messages( + self, messages: List[BaseMessage] + ) -> List[BaseMessage]: + messages = super()._validate_and_cleanup_messages(messages) + + # AWS Titan doesn't support empty messages, + # so we replace it with a single space. + for msg in messages: + msg.content = msg.content or " " + + return messages + + async def _apredict( + self, consumer: Consumer, model_params: ModelParameters, prompt: str + ): + await make_async( + lambda args: self._call(*args), (consumer, model_params, prompt) + ) + + def _call( + self, consumer: Consumer, model_params: ModelParameters, prompt: str + ): + model_kwargs = prepare_model_kwargs(model_params) + + invoke_params = { + "modelId": self.model_id, + "accept": "application/json", + "contentType": "application/json", + "body": json.dumps(prepare_input(prompt, model_kwargs)), + } + + usage = TokenUsage() + + if not model_params.stream: + response = self.bedrock.invoke_model(**invoke_params) + content_stream = get_generator_for_non_streaming(response, usage) + else: + response = self.bedrock.invoke_model_with_response_stream( + **invoke_params + ) + content_stream = get_generator_for_streaming(response, usage) + + content_stream = post_process_stream(model_params, content_stream) + + for content in content_stream: + log.debug(f"content: {repr(content)}") + consumer.append_content(content) + + consumer.add_usage(usage) diff --git a/aidial_adapter_bedrock/llm/model/anthropic.py b/aidial_adapter_bedrock/llm/model/anthropic.py new file mode 100644 index 00000000..13274893 --- /dev/null +++ b/aidial_adapter_bedrock/llm/model/anthropic.py @@ -0,0 +1,140 @@ +import json +from typing import Any, Dict, Generator, List, Optional + +from anthropic.tokenizer import count_tokens + +from aidial_adapter_bedrock.llm.chat_emulation import claude_chat +from aidial_adapter_bedrock.llm.chat_emulation.claude_chat import ( + ClaudeChatHistory, +) +from aidial_adapter_bedrock.llm.chat_model import ChatModel, ChatPrompt +from aidial_adapter_bedrock.llm.consumer import Consumer +from aidial_adapter_bedrock.llm.message import BaseMessage +from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_ANTHROPIC +from aidial_adapter_bedrock.universal_api.request import ModelParameters +from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage +from aidial_adapter_bedrock.utils.concurrency import make_async +from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log + + +def compute_usage(prompt: str, completion: str) -> TokenUsage: + return TokenUsage( + prompt_tokens=count_tokens(prompt), + completion_tokens=count_tokens(completion), + ) + + +# NOTE: See https://docs.anthropic.com/claude/reference/complete_post +def prepare_model_kwargs(model_params: ModelParameters) -> Dict[str, Any]: + model_kwargs = {} + + if model_params.max_tokens is not None: + model_kwargs["max_tokens_to_sample"] = model_params.max_tokens + else: + # The max tokens parameter is required for Anthropic models. + # Choosing reasonable default. 
+ model_kwargs["max_tokens_to_sample"] = DEFAULT_MAX_TOKENS_ANTHROPIC + + if model_params.stop is not None: + model_kwargs["stop_sequences"] = ( + [model_params.stop] + if isinstance(model_params.stop, str) + else model_params.stop + ) + + if model_params.temperature is not None: + model_kwargs["temperature"] = model_params.temperature + + if model_params.top_p is not None: + model_kwargs["top_p"] = model_params.top_p + + return model_kwargs + + +def prepare_input(prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: + return {"prompt": prompt, **model_kwargs} + + +def get_generator_for_streaming(response: Any) -> Generator[str, None, None]: + body = response["body"] + for event in body: + chunk = event.get("chunk") + if chunk: + chunk_obj = json.loads(chunk.get("bytes").decode()) + log.debug(f"chunk: {chunk_obj}") + + yield chunk_obj["completion"] + + +def get_generator_for_non_streaming( + response: Any, +) -> Generator[str, None, None]: + body = json.loads(response["body"].read()) + log.debug(f"body: {body}") + yield body["completion"] + + +class AnthropicAdapter(ChatModel): + def __init__( + self, + bedrock: Any, + model_id: str, + ): + super().__init__(model_id) + self.bedrock = bedrock + + def _prepare_prompt( + self, messages: List[BaseMessage], max_prompt_tokens: Optional[int] + ) -> ChatPrompt: + history = ClaudeChatHistory.create(messages) + if max_prompt_tokens is None: + return ChatPrompt( + text=history.format(), stop_sequences=claude_chat.STOP_SEQUENCES + ) + + history, discarded_messages_count = history.trim( + count_tokens, max_prompt_tokens + ) + + return ChatPrompt( + text=history.format(), + stop_sequences=claude_chat.STOP_SEQUENCES, + discarded_messages=discarded_messages_count, + ) + + async def _apredict( + self, consumer: Consumer, model_params: ModelParameters, prompt: str + ): + return await make_async( + lambda args: self._predict(*args), (consumer, model_params, prompt) + ) + + def _predict( + self, consumer: Consumer, model_params: ModelParameters, prompt: str + ): + model_kwargs = prepare_model_kwargs(model_params) + + invoke_params = { + "modelId": self.model_id, + "accept": "application/json", + "contentType": "application/json", + "body": json.dumps(prepare_input(prompt, model_kwargs)), + } + + if not model_params.stream: + response = self.bedrock.invoke_model(**invoke_params) + content_stream = get_generator_for_non_streaming(response) + + else: + response = self.bedrock.invoke_model_with_response_stream( + **invoke_params + ) + content_stream = get_generator_for_streaming(response) + + completion = "" + + for content in content_stream: + completion += content + consumer.append_content(content) + + consumer.add_usage(compute_usage(prompt, completion)) diff --git a/aidial_adapter_bedrock/llm/model/conf.py b/aidial_adapter_bedrock/llm/model/conf.py new file mode 100644 index 00000000..dbdcedf2 --- /dev/null +++ b/aidial_adapter_bedrock/llm/model/conf.py @@ -0,0 +1,3 @@ +DEFAULT_MAX_TOKENS_AI21 = 512 +DEFAULT_MAX_TOKENS_AMAZON = 512 +DEFAULT_MAX_TOKENS_ANTHROPIC = 512 diff --git a/aidial_adapter_bedrock/llm/model/stability.py b/aidial_adapter_bedrock/llm/model/stability.py new file mode 100644 index 00000000..242af573 --- /dev/null +++ b/aidial_adapter_bedrock/llm/model/stability.py @@ -0,0 +1,123 @@ +import json +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +from aidial_adapter_bedrock.llm.chat_emulation.zero_memory_chat import ( + ZeroMemoryChatHistory, +) +from 
aidial_adapter_bedrock.llm.chat_model import ChatModel, ChatPrompt +from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer +from aidial_adapter_bedrock.llm.message import BaseMessage +from aidial_adapter_bedrock.universal_api.request import ModelParameters +from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage +from aidial_adapter_bedrock.utils.concurrency import make_async + + +class ResponseData(BaseModel): + mime_type: str + name: str + content: str + + +class StabilityStatus(str, Enum): + SUCCESS = "success" + ERROR = "error" + + +class StabilityError(BaseModel): + id: str + message: str + name: str + + +class StabilityArtifact(BaseModel): + seed: int + base64: str + finish_reason: str = Field(alias="finishReason") + + +class StabilityResponse(BaseModel): + # TODO: Use tagged union artifacts/error + result: str + artifacts: Optional[list[StabilityArtifact]] + error: Optional[StabilityError] + + def content(self) -> str: + self._throw_if_error() + return "" + + def data(self) -> list[ResponseData]: + self._throw_if_error() + return [ + ResponseData( + mime_type="image/png", + name="image", + content=self.artifacts[0].base64, # type: ignore + ) + ] + + def usage(self) -> TokenUsage: + return TokenUsage( + prompt_tokens=0, + completion_tokens=1, + ) + + def _throw_if_error(self): + if self.result == StabilityStatus.ERROR: + raise Exception(self.error.message) # type: ignore + + +def prepare_input(prompt: str) -> Dict[str, Any]: + return {"text_prompts": [{"text": prompt}]} + + +class StabilityAdapter(ChatModel): + def __init__( + self, + bedrock: Any, + model_id: str, + ): + super().__init__(model_id) + self.bedrock = bedrock + + def _prepare_prompt( + self, messages: List[BaseMessage], max_prompt_tokens: Optional[int] + ) -> ChatPrompt: + history = ZeroMemoryChatHistory.create(messages) + return ChatPrompt( + text=history.format(), + stop_sequences=[], + discarded_messages=history.discarded_messages, + ) + + async def _apredict( + self, consumer: Consumer, model_params: ModelParameters, prompt: str + ): + return await make_async( + lambda args: self._call(*args), (consumer, prompt) + ) + + def _call(self, consumer: Consumer, prompt: str): + model_response = self.bedrock.invoke_model( + modelId=self.model_id, + accept="application/json", + contentType="application/json", + body=json.dumps(prepare_input(prompt)), + ) + + body = json.loads(model_response["body"].read()) + resp = StabilityResponse.parse_obj(body) + + consumer.append_content(resp.content()) + consumer.add_usage(resp.usage()) + + for data in resp.data(): + consumer.add_attachment( + Attachment( + title=data.name, + data=data.content, + type=data.mime_type, + ) + ) diff --git a/aidial_adapter_bedrock/llm/model_listing.py b/aidial_adapter_bedrock/llm/model_listing.py new file mode 100644 index 00000000..c2467c9c --- /dev/null +++ b/aidial_adapter_bedrock/llm/model_listing.py @@ -0,0 +1,14 @@ +from typing import List, TypedDict + +import boto3 + + +class BedrockModelId(TypedDict): + modelArn: str + modelId: str + + +def get_bedrock_models(region: str) -> List[BedrockModelId]: + session = boto3.Session() + bedrock = session.client("bedrock", region) + return bedrock.list_foundation_models()["modelSummaries"] diff --git a/aidial_adapter_bedrock/universal_api/request.py b/aidial_adapter_bedrock/universal_api/request.py index 4d21d682..c2699cc8 100644 --- a/aidial_adapter_bedrock/universal_api/request.py +++ b/aidial_adapter_bedrock/universal_api/request.py @@ -10,9 +10,11 @@ class 
ModelParameters(BaseModel): n: Optional[int] = None stop: Optional[Union[str, List[str]]] = None max_tokens: Optional[int] = None + max_prompt_tokens: Optional[int] = None presence_penalty: Optional[float] = None frequency_penalty: Optional[float] = None logit_bias: Optional[Mapping[int, float]] = None + stream: bool = False @classmethod def create(cls, request: Request) -> "ModelParameters": @@ -22,7 +24,22 @@ def create(cls, request: Request) -> "ModelParameters": n=request.n, stop=request.stop, max_tokens=request.max_tokens, + max_prompt_tokens=request.max_prompt_tokens, presence_penalty=request.presence_penalty, frequency_penalty=request.frequency_penalty, logit_bias=request.logit_bias, + stream=request.stream, ) + + def add_stop_sequences(self, stop: List[str]) -> "ModelParameters": + if len(stop) == 0: + return self + + self_stop: List[str] = [] + if self.stop is not None: + if isinstance(self.stop, str): + self_stop = [self.stop] + else: + self_stop = self.stop + + return self.copy(update={"stop": [*self_stop, *stop]}) diff --git a/aidial_adapter_bedrock/universal_api/token_usage.py b/aidial_adapter_bedrock/universal_api/token_usage.py index 814232b7..194312ed 100644 --- a/aidial_adapter_bedrock/universal_api/token_usage.py +++ b/aidial_adapter_bedrock/universal_api/token_usage.py @@ -9,8 +9,10 @@ class TokenUsage(BaseModel): def total_tokens(self) -> int: return self.prompt_tokens + self.completion_tokens + def accumulate(self, other: "TokenUsage") -> "TokenUsage": + self.prompt_tokens += other.prompt_tokens + self.completion_tokens += other.completion_tokens + return self + def __add__(self, other: "TokenUsage") -> "TokenUsage": - return TokenUsage( - prompt_tokens=self.prompt_tokens + other.prompt_tokens, - completion_tokens=self.completion_tokens + other.completion_tokens, - ) + return self.copy().accumulate(other) diff --git a/aidial_adapter_bedrock/utils/list.py b/aidial_adapter_bedrock/utils/list.py new file mode 100644 index 00000000..1216f577 --- /dev/null +++ b/aidial_adapter_bedrock/utils/list.py @@ -0,0 +1,9 @@ +from typing import List, Set, TypeVar + +T = TypeVar("T") + + +def exclude_indices(input_list: List[T], indices: Set[int]) -> List[T]: + return [ + item for index, item in enumerate(input_list) if index not in indices + ] diff --git a/aidial_adapter_bedrock/utils/log_config.py b/aidial_adapter_bedrock/utils/log_config.py index badf519f..2a5e52da 100644 --- a/aidial_adapter_bedrock/utils/log_config.py +++ b/aidial_adapter_bedrock/utils/log_config.py @@ -1,12 +1,16 @@ import logging import os +from aidial_sdk import logger as aidial_logger from pydantic import BaseModel # By default (in prod) we don't want to print debug messages, # because they typically contain prompts. 
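+# LOG_LEVEL controls the adapter's own loggers; AIDIAL_LOG_LEVEL (set below)
+# controls the aidial_sdk logger separately and defaults to WARNING.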
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") +AIDIAL_LOG_LEVEL = os.getenv("AIDIAL_LOG_LEVEL", "WARNING") +aidial_logger.setLevel(AIDIAL_LOG_LEVEL) + class LogConfig(BaseModel): """Logging configuration to be set for the server""" @@ -16,7 +20,7 @@ class LogConfig(BaseModel): formatters = { "default": { "()": "uvicorn.logging.DefaultFormatter", - "fmt": "%(levelprefix)s | %(asctime)s | %(name)s | %(process)d | %(message)s", + "fmt": "%(levelprefix)s | %(asctime)s | %(process)d | %(name)s | %(message)s", "datefmt": "%Y-%m-%d %H:%M:%S", "use_colors": True, }, diff --git a/aidial_adapter_bedrock/utils/operators.py b/aidial_adapter_bedrock/utils/operators.py deleted file mode 100644 index eaa191a5..00000000 --- a/aidial_adapter_bedrock/utils/operators.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Callable, TypeVar - -T = TypeVar("T") - -Unary = Callable[[T], T] - - -def identity(x: T) -> T: - return x diff --git a/aidial_adapter_bedrock/utils/stream.py b/aidial_adapter_bedrock/utils/stream.py new file mode 100644 index 00000000..fb5b7daa --- /dev/null +++ b/aidial_adapter_bedrock/utils/stream.py @@ -0,0 +1,78 @@ +from typing import Generator, List + +import tests.utils.string as string + + +def lstrip(gen: Generator[str, None, None]) -> Generator[str, None, None]: + start = True + for chunk in gen: + if start: + chunk = chunk.lstrip() + if chunk != "": + start = False + yield chunk + else: + yield chunk + + +def remove_prefix( + gen: Generator[str, None, None], prefix: str +) -> Generator[str, None, None]: + acc = "" + start = True + + for chunk in gen: + if start: + acc += chunk + if len(acc) >= len(prefix): + yield string.remove_prefix(prefix, acc) + start = False + else: + yield chunk + + if start: + yield acc + + +def stop_at( + gen: Generator[str, None, None], stop_sequences: List[str] +) -> Generator[str, None, None]: + if len(stop_sequences) == 0: + yield from gen + return + + buffer_len = max(map(len, stop_sequences)) - 1 + + hold = "" + for chunk in gen: + hold += chunk + + min_index = len(hold) + for stop_sequence in stop_sequences: + if stop_sequence in hold: + min_index = min(min_index, hold.index(stop_sequence)) + + if min_index < len(hold): + commit = hold[:min_index] + if commit: + yield commit + return + + commit, hold = hold[:-buffer_len], hold[-buffer_len:] + if commit: + yield commit + + if hold: + yield hold + + +def ensure_not_empty( + gen: Generator[str, None, None], default: str +) -> Generator[str, None, None]: + all_chunks_are_empty = True + for chunk in gen: + all_chunks_are_empty = all_chunks_are_empty and chunk == "" + yield chunk + + if all_chunks_are_empty: + yield default diff --git a/aidial_adapter_bedrock/utils/text.py b/aidial_adapter_bedrock/utils/text.py deleted file mode 100644 index b8986e84..00000000 --- a/aidial_adapter_bedrock/utils/text.py +++ /dev/null @@ -1,22 +0,0 @@ -import re -from typing import List - - -def remove_prefix(text: str, prefix: str) -> str: - if text.startswith(prefix): - return text[len(prefix) :] - return text - - -# Copy of langchain.llms.utils::enforce_stop_tokens with a bugfix: stop words are escaped. 
-def enforce_stop_tokens(text: str, stop: None | List[str] | str) -> str: - """Cut off the text as soon as any stop words occur.""" - - if stop is None: - return text - - if isinstance(stop, str): - stop = [stop] - - stop_escaped = [re.escape(s) for s in stop] - return re.split("|".join(stop_escaped), text)[0] diff --git a/client/client_bedrock.py b/client/client_bedrock.py index 29f90802..eba326eb 100755 --- a/client/client_bedrock.py +++ b/client/client_bedrock.py @@ -3,12 +3,14 @@ from aidial_sdk.chat_completion import Message, Role -from aidial_adapter_bedrock.llm.bedrock_adapter import BedrockAdapter +from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment +from aidial_adapter_bedrock.llm.consumer import CollectConsumer +from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter from aidial_adapter_bedrock.universal_api.request import ModelParameters from aidial_adapter_bedrock.utils.env import get_env from aidial_adapter_bedrock.utils.printing import print_ai, print_info from client.conf import MAX_CHAT_TURN, MAX_INPUT_CHARS -from client.utils.cli import choose_deployment +from client.utils.cli import select_enum from client.utils.init import init from client.utils.input import make_input @@ -16,15 +18,16 @@ async def main(): location = get_env("DEFAULT_REGION") - deployment, chat_emulation_type = choose_deployment() + deployment = select_enum("Select the deployment", BedrockDeployment) - model = await BedrockAdapter.create( + model_params = ModelParameters() + + model = await get_bedrock_adapter( model_id=deployment.get_model_id(), - model_params=ModelParameters(), region=location, ) - history: List[Message] = [] + messages: List[Message] = [] chat_input = make_input() @@ -33,14 +36,15 @@ async def main(): turn += 1 content = chat_input()[:MAX_INPUT_CHARS] - history.append(Message(role=Role.USER, content=content)) + messages.append(Message(role=Role.USER, content=content)) - response = await model.achat(chat_emulation_type, history) + response = CollectConsumer() + await model.achat(response, model_params, messages) print_info(response.usage.json(indent=2)) print_ai(response.content.strip()) - history.append(Message(role=Role.ASSISTANT, content=response.content)) + messages.append(Message(role=Role.ASSISTANT, content=response.content)) if __name__ == "__main__": diff --git a/client/utils/cli.py b/client/utils/cli.py index b0968077..efa3882d 100644 --- a/client/utils/cli.py +++ b/client/utils/cli.py @@ -1,11 +1,8 @@ from enum import Enum -from typing import List, Tuple, Type, TypeVar +from typing import List, Type, TypeVar import inquirer -from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment -from aidial_adapter_bedrock.llm.chat_emulation.types import ChatEmulationType - V = TypeVar("V") @@ -34,12 +31,3 @@ def select_enum(title: str, enum: Type[T]) -> T: ), ] return inquirer.prompt(questions)["option"] # type: ignore - - -def choose_deployment() -> Tuple[BedrockDeployment, ChatEmulationType]: - deployment = select_enum("Select the deployment", BedrockDeployment) - chat_emulation_type = select_enum( - "Select chat emulation type", ChatEmulationType - ) - - return deployment, chat_emulation_type diff --git a/client/utils/input.py b/client/utils/input.py index c54f68d7..87e526e6 100644 --- a/client/utils/input.py +++ b/client/utils/input.py @@ -15,7 +15,7 @@ def input(prompt_text="> ", style=Style.from_dict({"": "#ff0000"})): history=FileHistory(str(get_project_root() / ".history")) ) - response = session.prompt(prompt_text, style=style) 
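+        # in_thread=True runs the prompt in a background thread, so it can be
+        # called while the client's asyncio event loop is already running.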
+ response = session.prompt(prompt_text, style=style, in_thread=True) return response[:limit] return input diff --git a/poetry.lock b/poetry.lock index a86d5f2b..45b3df0a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -21,13 +21,13 @@ aws = ["aws-requests-auth", "boto3", "sagemaker"] [[package]] name = "aidial-sdk" -version = "0.1.0" +version = "0.1.2" description = "Framework to create applications and model adapters for AI DIAL" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "aidial_sdk-0.1.0-py3-none-any.whl", hash = "sha256:596c9e7aca688e56b1749fb70b0c97ebd508827b6a39cfe6035a3b860cf9f7af"}, - {file = "aidial_sdk-0.1.0.tar.gz", hash = "sha256:fe8fa9ea9d3ccd3f9e719daac08d8dd946f423cb4f2511d9ec43bcc747ef51ad"}, + {file = "aidial_sdk-0.1.2-py3-none-any.whl", hash = "sha256:cb930e72964bae4eac59dbcf69d3fb4df3051972c57b4c073488a402873e26a7"}, + {file = "aidial_sdk-0.1.2.tar.gz", hash = "sha256:005b63fa2559debac41bd0d4ddd1bc3b764fb3373c3c83942f3823c8536d4371"}, ] [package.dependencies] @@ -1235,16 +1235,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2056,103 +2046,6 @@ 
files = [ attrs = ">=22.2.0" rpds-py = ">=0.7.0" -[[package]] -name = "regex" -version = "2023.10.3" -description = "Alternative regular expression module, to replace re." -optional = false -python-versions = ">=3.7" -files = [ - {file = "regex-2023.10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4c34d4f73ea738223a094d8e0ffd6d2c1a1b4c175da34d6b0de3d8d69bee6bcc"}, - {file = "regex-2023.10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8f4e49fc3ce020f65411432183e6775f24e02dff617281094ba6ab079ef0915"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cd1bccf99d3ef1ab6ba835308ad85be040e6a11b0977ef7ea8c8005f01a3c29"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81dce2ddc9f6e8f543d94b05d56e70d03a0774d32f6cca53e978dc01e4fc75b8"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c6b4d23c04831e3ab61717a707a5d763b300213db49ca680edf8bf13ab5d91b"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c15ad0aee158a15e17e0495e1e18741573d04eb6da06d8b84af726cfc1ed02ee"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6239d4e2e0b52c8bd38c51b760cd870069f0bdf99700a62cd509d7a031749a55"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4a8bf76e3182797c6b1afa5b822d1d5802ff30284abe4599e1247be4fd6b03be"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d9c727bbcf0065cbb20f39d2b4f932f8fa1631c3e01fcedc979bd4f51fe051c5"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3ccf2716add72f80714b9a63899b67fa711b654be3fcdd34fa391d2d274ce767"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:107ac60d1bfdc3edb53be75e2a52aff7481b92817cfdddd9b4519ccf0e54a6ff"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:00ba3c9818e33f1fa974693fb55d24cdc8ebafcb2e4207680669d8f8d7cca79a"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f0a47efb1dbef13af9c9a54a94a0b814902e547b7f21acb29434504d18f36e3a"}, - {file = "regex-2023.10.3-cp310-cp310-win32.whl", hash = "sha256:36362386b813fa6c9146da6149a001b7bd063dabc4d49522a1f7aa65b725c7ec"}, - {file = "regex-2023.10.3-cp310-cp310-win_amd64.whl", hash = "sha256:c65a3b5330b54103e7d21cac3f6bf3900d46f6d50138d73343d9e5b2900b2353"}, - {file = "regex-2023.10.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:90a79bce019c442604662d17bf69df99090e24cdc6ad95b18b6725c2988a490e"}, - {file = "regex-2023.10.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c7964c2183c3e6cce3f497e3a9f49d182e969f2dc3aeeadfa18945ff7bdd7051"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ef80829117a8061f974b2fda8ec799717242353bff55f8a29411794d635d964"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5addc9d0209a9afca5fc070f93b726bf7003bd63a427f65ef797a931782e7edc"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c148bec483cc4b421562b4bcedb8e28a3b84fcc8f0aa4418e10898f3c2c0eb9b"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:8d1f21af4c1539051049796a0f50aa342f9a27cde57318f2fc41ed50b0dbc4ac"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b9ac09853b2a3e0d0082104036579809679e7715671cfbf89d83c1cb2a30f58"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ebedc192abbc7fd13c5ee800e83a6df252bec691eb2c4bedc9f8b2e2903f5e2a"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d8a993c0a0ffd5f2d3bda23d0cd75e7086736f8f8268de8a82fbc4bd0ac6791e"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:be6b7b8d42d3090b6c80793524fa66c57ad7ee3fe9722b258aec6d0672543fd0"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4023e2efc35a30e66e938de5aef42b520c20e7eda7bb5fb12c35e5d09a4c43f6"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0d47840dc05e0ba04fe2e26f15126de7c755496d5a8aae4a08bda4dd8d646c54"}, - {file = "regex-2023.10.3-cp311-cp311-win32.whl", hash = "sha256:9145f092b5d1977ec8c0ab46e7b3381b2fd069957b9862a43bd383e5c01d18c2"}, - {file = "regex-2023.10.3-cp311-cp311-win_amd64.whl", hash = "sha256:b6104f9a46bd8743e4f738afef69b153c4b8b592d35ae46db07fc28ae3d5fb7c"}, - {file = "regex-2023.10.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff507ae210371d4b1fe316d03433ac099f184d570a1a611e541923f78f05037"}, - {file = "regex-2023.10.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be5e22bbb67924dea15039c3282fa4cc6cdfbe0cbbd1c0515f9223186fc2ec5f"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a992f702c9be9c72fa46f01ca6e18d131906a7180950958f766c2aa294d4b41"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7434a61b158be563c1362d9071358f8ab91b8d928728cd2882af060481244c9e"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2169b2dcabf4e608416f7f9468737583ce5f0a6e8677c4efbf795ce81109d7c"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9e908ef5889cda4de038892b9accc36d33d72fb3e12c747e2799a0e806ec841"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12bd4bc2c632742c7ce20db48e0d99afdc05e03f0b4c1af90542e05b809a03d9"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bc72c231f5449d86d6c7d9cc7cd819b6eb30134bb770b8cfdc0765e48ef9c420"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bce8814b076f0ce5766dc87d5a056b0e9437b8e0cd351b9a6c4e1134a7dfbda9"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ba7cd6dc4d585ea544c1412019921570ebd8a597fabf475acc4528210d7c4a6f"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b0c7d2f698e83f15228ba41c135501cfe7d5740181d5903e250e47f617eb4292"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5a8f91c64f390ecee09ff793319f30a0f32492e99f5dc1c72bc361f23ccd0a9a"}, - {file = "regex-2023.10.3-cp312-cp312-win32.whl", hash = "sha256:ad08a69728ff3c79866d729b095872afe1e0557251da4abb2c5faff15a91d19a"}, - {file = "regex-2023.10.3-cp312-cp312-win_amd64.whl", hash = "sha256:39cdf8d141d6d44e8d5a12a8569d5a227f645c87df4f92179bd06e2e2705e76b"}, - {file = "regex-2023.10.3-cp37-cp37m-macosx_10_9_x86_64.whl", 
hash = "sha256:4a3ee019a9befe84fa3e917a2dd378807e423d013377a884c1970a3c2792d293"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76066d7ff61ba6bf3cb5efe2428fc82aac91802844c022d849a1f0f53820502d"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfe50b61bab1b1ec260fa7cd91106fa9fece57e6beba05630afe27c71259c59b"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fd88f373cb71e6b59b7fa597e47e518282455c2734fd4306a05ca219a1991b0"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3ab05a182c7937fb374f7e946f04fb23a0c0699c0450e9fb02ef567412d2fa3"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dac37cf08fcf2094159922edc7a2784cfcc5c70f8354469f79ed085f0328ebdf"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e54ddd0bb8fb626aa1f9ba7b36629564544954fff9669b15da3610c22b9a0991"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3367007ad1951fde612bf65b0dffc8fd681a4ab98ac86957d16491400d661302"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:16f8740eb6dbacc7113e3097b0a36065a02e37b47c936b551805d40340fb9971"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:f4f2ca6df64cbdd27f27b34f35adb640b5d2d77264228554e68deda54456eb11"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:39807cbcbe406efca2a233884e169d056c35aa7e9f343d4e78665246a332f597"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7eece6fbd3eae4a92d7c748ae825cbc1ee41a89bb1c3db05b5578ed3cfcfd7cb"}, - {file = "regex-2023.10.3-cp37-cp37m-win32.whl", hash = "sha256:ce615c92d90df8373d9e13acddd154152645c0dc060871abf6bd43809673d20a"}, - {file = "regex-2023.10.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0f649fa32fe734c4abdfd4edbb8381c74abf5f34bc0b3271ce687b23729299ed"}, - {file = "regex-2023.10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9b98b7681a9437262947f41c7fac567c7e1f6eddd94b0483596d320092004533"}, - {file = "regex-2023.10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:91dc1d531f80c862441d7b66c4505cd6ea9d312f01fb2f4654f40c6fdf5cc37a"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82fcc1f1cc3ff1ab8a57ba619b149b907072e750815c5ba63e7aa2e1163384a4"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7979b834ec7a33aafae34a90aad9f914c41fd6eaa8474e66953f3f6f7cbd4368"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef71561f82a89af6cfcbee47f0fabfdb6e63788a9258e913955d89fdd96902ab"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd829712de97753367153ed84f2de752b86cd1f7a88b55a3a775eb52eafe8a94"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00e871d83a45eee2f8688d7e6849609c2ca2a04a6d48fba3dff4deef35d14f07"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = 
"sha256:706e7b739fdd17cb89e1fbf712d9dc21311fc2333f6d435eac2d4ee81985098c"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cc3f1c053b73f20c7ad88b0d1d23be7e7b3901229ce89f5000a8399746a6e039"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6f85739e80d13644b981a88f529d79c5bdf646b460ba190bffcaf6d57b2a9863"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:741ba2f511cc9626b7561a440f87d658aabb3d6b744a86a3c025f866b4d19e7f"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e77c90ab5997e85901da85131fd36acd0ed2221368199b65f0d11bca44549711"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:979c24cbefaf2420c4e377ecd1f165ea08cc3d1fbb44bdc51bccbbf7c66a2cb4"}, - {file = "regex-2023.10.3-cp38-cp38-win32.whl", hash = "sha256:58837f9d221744d4c92d2cf7201c6acd19623b50c643b56992cbd2b745485d3d"}, - {file = "regex-2023.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:c55853684fe08d4897c37dfc5faeff70607a5f1806c8be148f1695be4a63414b"}, - {file = "regex-2023.10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2c54e23836650bdf2c18222c87f6f840d4943944146ca479858404fedeb9f9af"}, - {file = "regex-2023.10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:69c0771ca5653c7d4b65203cbfc5e66db9375f1078689459fe196fe08b7b4930"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ac965a998e1388e6ff2e9781f499ad1eaa41e962a40d11c7823c9952c77123e"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c0e8fae5b27caa34177bdfa5a960c46ff2f78ee2d45c6db15ae3f64ecadde14"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c56c3d47da04f921b73ff9415fbaa939f684d47293f071aa9cbb13c94afc17d"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ef1e014eed78ab650bef9a6a9cbe50b052c0aebe553fb2881e0453717573f52"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d29338556a59423d9ff7b6eb0cb89ead2b0875e08fe522f3e068b955c3e7b59b"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9c6d0ced3c06d0f183b73d3c5920727268d2201aa0fe6d55c60d68c792ff3588"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:994645a46c6a740ee8ce8df7911d4aee458d9b1bc5639bc968226763d07f00fa"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:66e2fe786ef28da2b28e222c89502b2af984858091675044d93cb50e6f46d7af"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:11175910f62b2b8c055f2b089e0fedd694fe2be3941b3e2633653bc51064c528"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:06e9abc0e4c9ab4779c74ad99c3fc10d3967d03114449acc2c2762ad4472b8ca"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fb02e4257376ae25c6dd95a5aec377f9b18c09be6ebdefa7ad209b9137b73d48"}, - {file = "regex-2023.10.3-cp39-cp39-win32.whl", hash = "sha256:3b2c3502603fab52d7619b882c25a6850b766ebd1b18de3df23b2f939360e1bd"}, - {file = "regex-2023.10.3-cp39-cp39-win_amd64.whl", hash = "sha256:adbccd17dcaff65704c856bd29951c58a1bd4b2b0f8ad6b826dbd543fe740988"}, - {file = "regex-2023.10.3.tar.gz", hash = 
"sha256:3fef4f844d2290ee0ba57addcec17eec9e3df73f10a2748485dfd6a3a188cc0f"}, -] - [[package]] name = "requests" version = "2.31.0" @@ -2420,14 +2313,6 @@ files = [ {file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:56628ca27aa17b5890391ded4e385bf0480209726f198799b7e980c6bd473bd7"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db726be58837fe5ac39859e0fa40baafe54c6d54c02aba1d47d25536170b690f"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7421c1bfdbb7214313919472307be650bd45c4dc2fcb317d64d078993de045b"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:632784f7a6f12cfa0e84bf2a5003b07660addccf5563c132cd23b7cc1d7371a9"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f6f7276cf26145a888f2182a98f204541b519d9ea358a65d82095d9c9e22f917"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2a1f7ffac934bc0ea717fa1596f938483fb8c402233f9b26679b4f7b38d6ab6e"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-win32.whl", hash = "sha256:bfece2f7cec502ec5f759bbc09ce711445372deeac3628f6fa1c16b7fb45b682"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-win_amd64.whl", hash = "sha256:526b869a0f4f000d8d8ee3409d0becca30ae73f494cbb48801da0129601f72c6"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"}, @@ -2525,51 +2410,6 @@ files = [ [package.extras] doc = ["reno", "sphinx", "tornado (>=4.5)"] -[[package]] -name = "tiktoken" -version = "0.4.0" -description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tiktoken-0.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:176cad7f053d2cc82ce7e2a7c883ccc6971840a4b5276740d0b732a2b2011f8a"}, - {file = "tiktoken-0.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:450d504892b3ac80207700266ee87c932df8efea54e05cefe8613edc963c1285"}, - {file = "tiktoken-0.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00d662de1e7986d129139faf15e6a6ee7665ee103440769b8dedf3e7ba6ac37f"}, - {file = "tiktoken-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5727d852ead18b7927b8adf558a6f913a15c7766725b23dbe21d22e243041b28"}, - {file = "tiktoken-0.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c06cd92b09eb0404cedce3702fa866bf0d00e399439dad3f10288ddc31045422"}, - {file = "tiktoken-0.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9ec161e40ed44e4210d3b31e2ff426b4a55e8254f1023e5d2595cb60044f8ea6"}, - {file = 
"tiktoken-0.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:1e8fa13cf9889d2c928b9e258e9dbbbf88ab02016e4236aae76e3b4f82dd8288"}, - {file = "tiktoken-0.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bb2341836b725c60d0ab3c84970b9b5f68d4b733a7bcb80fb25967e5addb9920"}, - {file = "tiktoken-0.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2ca30367ad750ee7d42fe80079d3092bd35bb266be7882b79c3bd159b39a17b0"}, - {file = "tiktoken-0.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3dc3df19ddec79435bb2a94ee46f4b9560d0299c23520803d851008445671197"}, - {file = "tiktoken-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d980fa066e962ef0f4dad0222e63a484c0c993c7a47c7dafda844ca5aded1f3"}, - {file = "tiktoken-0.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:329f548a821a2f339adc9fbcfd9fc12602e4b3f8598df5593cfc09839e9ae5e4"}, - {file = "tiktoken-0.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b1a038cee487931a5caaef0a2e8520e645508cde21717eacc9af3fbda097d8bb"}, - {file = "tiktoken-0.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:08efa59468dbe23ed038c28893e2a7158d8c211c3dd07f2bbc9a30e012512f1d"}, - {file = "tiktoken-0.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f3020350685e009053829c1168703c346fb32c70c57d828ca3742558e94827a9"}, - {file = "tiktoken-0.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba16698c42aad8190e746cd82f6a06769ac7edd415d62ba027ea1d99d958ed93"}, - {file = "tiktoken-0.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c15d9955cc18d0d7ffcc9c03dc51167aedae98542238b54a2e659bd25fe77ed"}, - {file = "tiktoken-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64e1091c7103100d5e2c6ea706f0ec9cd6dc313e6fe7775ef777f40d8c20811e"}, - {file = "tiktoken-0.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e87751b54eb7bca580126353a9cf17a8a8eaadd44edaac0e01123e1513a33281"}, - {file = "tiktoken-0.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e063b988b8ba8b66d6cc2026d937557437e79258095f52eaecfafb18a0a10c03"}, - {file = "tiktoken-0.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:9c6dd439e878172dc163fced3bc7b19b9ab549c271b257599f55afc3a6a5edef"}, - {file = "tiktoken-0.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8d1d97f83697ff44466c6bef5d35b6bcdb51e0125829a9c0ed1e6e39fb9a08fb"}, - {file = "tiktoken-0.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b6bce7c68aa765f666474c7c11a7aebda3816b58ecafb209afa59c799b0dd2d"}, - {file = "tiktoken-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a73286c35899ca51d8d764bc0b4d60838627ce193acb60cc88aea60bddec4fd"}, - {file = "tiktoken-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0394967d2236a60fd0aacef26646b53636423cc9c70c32f7c5124ebe86f3093"}, - {file = "tiktoken-0.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:dae2af6f03ecba5f679449fa66ed96585b2fa6accb7fd57d9649e9e398a94f44"}, - {file = "tiktoken-0.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:55e251b1da3c293432179cf7c452cfa35562da286786be5a8b1ee3405c2b0dd2"}, - {file = "tiktoken-0.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:c835d0ee1f84a5aa04921717754eadbc0f0a56cf613f78dfc1cf9ad35f6c3fea"}, - {file = "tiktoken-0.4.0.tar.gz", hash = "sha256:59b20a819969735b48161ced9b92f05dc4519c17be4015cfb73b65270a243620"}, -] - -[package.dependencies] -regex = ">=2022.1.18" -requests = ">=2.26.0" - -[package.extras] -blobfile = ["blobfile (>=2)"] - 
[[package]] name = "tokenizers" version = "0.14.1" @@ -3013,4 +2853,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11,<4.0" -content-hash = "cbcf380c5856882d238619f3aec8709e271cf95eae7a4573c8d8e33ec4be1aca" +content-hash = "3fb6f843b40e9f3d42d02652251c2bb28343a3bb5104b669edb2eede4b7441d6" diff --git a/pyproject.toml b/pyproject.toml index b7271393..eaeb5807 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,13 +17,12 @@ repository = "https://github.com/epam/ai-dial-adapter-bedrock/" python = "^3.11,<4.0" boto3 = "1.28.57" botocore = "1.31.57" -aidial-sdk = "0.1.0" +aidial-sdk = "0.1.2" anthropic = "0.2.10" colorama = "0.4.4" fastapi = "0.103.1" flask = "2.3.2" openai = "0.27.8" -tiktoken = "0.4.0" uvicorn = "0.23.2" pydantic = "1.10.12" diff --git a/scripts/find_token_limits.py b/scripts/find_token_limits.py deleted file mode 100755 index 3cade588..00000000 --- a/scripts/find_token_limits.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio - -from aidial_sdk.chat_completion import Message, Role - -from aidial_adapter_bedrock.llm.bedrock_adapter import BedrockAdapter -from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment -from aidial_adapter_bedrock.llm.chat_emulation.types import ChatEmulationType -from aidial_adapter_bedrock.universal_api.request import ModelParameters -from aidial_adapter_bedrock.utils.env import get_env -from aidial_adapter_bedrock.utils.printing import print_error, print_info -from client.utils.cli import select_enum -from client.utils.init import init - - -async def main(): - init() - - model_id = select_enum("Select model", BedrockDeployment) - - model = await BedrockAdapter.create( - model_id=model_id, - model_params=ModelParameters(max_tokens=1), - region=get_env("DEFAULT_REGION"), - ) - - base = "a " - - min_x = 1 - max_x = 100 * 1000 - x = 1 - - while True: - prompt: str = x * base - print(f"{min_x} <= {x} <= {max_x}") - - try: - response = await model.achat( - ChatEmulationType.ZERO_MEMORY, - [Message(role=Role.USER, content=prompt)], - ) - - print_info(f"{x}: " + response.usage.json(indent=2)) - min_x = x - next_x = (x + max_x) // 2 - - except Exception as e: - print_error(f"{x}: {str(e)}") - max_x = x - next_x = (min_x + x) // 2 - - if next_x == x: - break - else: - x = next_x - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py index 24e04c8c..8725d51d 100644 --- a/tests/integration_tests/test_chat_completion.py +++ b/tests/integration_tests/test_chat_completion.py @@ -1,6 +1,6 @@ import re from dataclasses import dataclass -from typing import Callable, List +from typing import Callable, List, Optional import openai import openai.error @@ -10,6 +10,7 @@ from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment from tests.conftest import TEST_SERVER_URL from tests.utils.llm import ( + ai, create_model, run_model, sanitize_test_name, @@ -25,13 +26,17 @@ class TestCase: name: str deployment: BedrockDeployment streaming: bool + max_tokens: Optional[int] + stop: Optional[List[str]] messages: List[BaseMessage] test: Callable[[str], bool] | Exception def get_id(self): + max_tokens_str = str(self.max_tokens) if self.max_tokens else "inf" + stop_sequence_str = str(self.stop) if self.stop else "nonstop" return sanitize_test_name( - f"{self.deployment.value} {self.streaming} {self.name}" + f"{self.deployment.value} {self.streaming} 
{max_tokens_str} {stop_sequence_str} {self.name}" ) @@ -52,12 +57,30 @@ def get_test_cases( ) -> List[TestCase]: ret: List[TestCase] = [] + ret.append( + TestCase( + name="dialog recall", + deployment=deployment, + streaming=streaming, + max_tokens=None, + stop=None, + messages=[ + user("my name is Anton"), + ai("nice to meet you"), + user("what's my name?"), + ], + test=lambda s: "anton" in s.lower(), + ) + ) + ret.append( TestCase( name="2+3=5", deployment=deployment, streaming=streaming, - messages=[user("2+3=?")], + max_tokens=None, + stop=None, + messages=[user("compute 2+3")], test=lambda s: "5" in s, ) ) @@ -67,7 +90,9 @@ def get_test_cases( name="empty system message", deployment=deployment, streaming=streaming, - messages=[sys(""), user("2+4=?")], + max_tokens=None, + stop=None, + messages=[sys(""), user("compute 2+4")], test=lambda s: "6" in s, ) ) @@ -81,6 +106,8 @@ def get_test_cases( name="hello", deployment=deployment, streaming=streaming, + max_tokens=None, + stop=None, messages=[user(query)], test=lambda s: "hello" in s.lower(), ) @@ -91,6 +118,8 @@ def get_test_cases( name="empty dialog", deployment=deployment, streaming=streaming, + max_tokens=1, + stop=None, messages=[], test=Exception("List of messages must not be empty"), ) @@ -101,11 +130,49 @@ def get_test_cases( name="empty user message", deployment=deployment, streaming=streaming, + max_tokens=1, + stop=None, messages=[user("")], test=lambda s: True, ) ) + ret.append( + TestCase( + name="single space user message", + deployment=deployment, + streaming=streaming, + max_tokens=1, + stop=None, + messages=[user(" ")], + test=lambda s: True, + ) + ) + + ret.append( + TestCase( + name="max tokens 1", + deployment=deployment, + streaming=streaming, + max_tokens=1, + stop=None, + messages=[user("tell me the full story of Pinocchio")], + test=lambda s: len(s.split()) <= 1, + ) + ) + + ret.append( + TestCase( + name="stop sequence", + deployment=deployment, + streaming=streaming, + max_tokens=None, + stop=["world"], + messages=[user('Reply with "hello world"')], + test=lambda s: "world" not in s.lower(), + ) + ) + return ret @@ -121,17 +188,21 @@ def get_test_cases( ids=lambda test: test.get_id(), ) async def test_chat_completion_langchain(server, test: TestCase): - model = create_model(TEST_SERVER_URL, test.deployment.value, test.streaming) + model = create_model( + TEST_SERVER_URL, test.deployment.value, test.streaming, test.max_tokens + ) if isinstance(test.test, Exception): with pytest.raises(Exception) as exc_info: - await run_model(model, test.messages, test.streaming) + await run_model(model, test.messages, test.streaming, test.stop) assert isinstance(exc_info.value, openai.error.OpenAIError) assert exc_info.value.http_status == 422 assert re.search(str(test.test), str(exc_info.value)) else: - actual_output = await run_model(model, test.messages, test.streaming) + actual_output = await run_model( + model, test.messages, test.streaming, test.stop + ) assert test.test( actual_output ), f"Failed output test, actual output: {actual_output}" diff --git a/tests/unit_tests/chat_emulation/__init__.py b/tests/unit_tests/chat_emulation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/chat_emulation/test_claude_chat_history.py b/tests/unit_tests/chat_emulation/test_claude_chat_history.py new file mode 100644 index 00000000..4c15e9ec --- /dev/null +++ b/tests/unit_tests/chat_emulation/test_claude_chat_history.py @@ -0,0 +1,123 @@ +import pytest + +from 
aidial_adapter_bedrock.llm.chat_emulation.claude_chat import ( + ClaudeChatHistory, +) +from aidial_adapter_bedrock.llm.chat_emulation.history import FormattedMessage +from aidial_adapter_bedrock.llm.exceptions import ValidationError +from aidial_adapter_bedrock.llm.message import ( + AIMessage, + HumanMessage, + SystemMessage, +) + + +def test_construction(): + messages = [ + SystemMessage(content="system message1"), + HumanMessage(content=" human message1 "), + AIMessage(content=" ai message1 "), + HumanMessage(content=" human message2 "), + ] + history = ClaudeChatHistory.create(messages) + + assert history.messages == [ + FormattedMessage( + text="\n\nHuman: system message1", source_message=messages[0] + ), + FormattedMessage( + text="\n\nHuman: human message1", + source_message=messages[1], + is_important=False, + ), + FormattedMessage( + text="\n\nAssistant: ai message1", + source_message=messages[2], + is_important=False, + ), + FormattedMessage( + text="\n\nHuman: human message2", source_message=messages[3] + ), + FormattedMessage(text="\n\nAssistant:"), + ] + + +def test_formatting(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2"), + FormattedMessage(text="text3"), + ] + history = ClaudeChatHistory(messages=messages) + + prompt = history.format() + + assert prompt == "text1text2text3" + + +def test_no_trimming(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2"), + FormattedMessage(text="text3"), + ] + history = ClaudeChatHistory(messages=messages) + + trimmed_history, discarded_messages_count = history.trim(lambda _: 1, 3) + + assert discarded_messages_count == 0 + assert trimmed_history == history + + +def test_trimming(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2", is_important=False), + FormattedMessage(text="text3"), + FormattedMessage(text="text4", is_important=False), + FormattedMessage(text="text5"), + ] + history = ClaudeChatHistory(messages=messages) + + trimmed_history, discarded_messages_count = history.trim(lambda _: 1, 3) + + assert discarded_messages_count == 2 + assert trimmed_history.messages == [ + FormattedMessage(text="text1"), + FormattedMessage(text="text3"), + FormattedMessage(text="text5"), + ] + + +def test_prompt_is_too_big(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2"), + FormattedMessage(text="text3"), + ] + history = ClaudeChatHistory(messages=messages) + + with pytest.raises(ValidationError) as exc_info: + history.trim(lambda _: 1, 2) + + assert ( + str(exc_info.value) + == "Prompt token size (3) exceeds prompt token limit (2)." + ) + + +def test_prompt_with_history_is_too_big(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2", is_important=False), + FormattedMessage(text="text3"), + ] + history = ClaudeChatHistory(messages=messages) + + with pytest.raises(ValidationError) as exc_info: + history.trim(lambda _: 1, 1) + + assert ( + str(exc_info.value) + == "The token size of system messages and the last user message (2) exceeds prompt token limit (1)." 
+ ) diff --git a/tests/unit_tests/chat_emulation/test_pseudo_chat_history.py b/tests/unit_tests/chat_emulation/test_pseudo_chat_history.py new file mode 100644 index 00000000..e3328e7c --- /dev/null +++ b/tests/unit_tests/chat_emulation/test_pseudo_chat_history.py @@ -0,0 +1,223 @@ +from typing import List + +import pytest + +from aidial_adapter_bedrock.llm.chat_emulation.history import FormattedMessage +from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import ( + PRELUDE, + PseudoChatHistory, +) +from aidial_adapter_bedrock.llm.exceptions import ValidationError +from aidial_adapter_bedrock.llm.message import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, +) + + +def test_construction(): + messages = [ + SystemMessage(content="system message1"), + HumanMessage(content=" human message1 "), + AIMessage(content=" ai message1 "), + HumanMessage(content=" human message2 "), + ] + history = PseudoChatHistory.create(messages) + + assert history.stop_sequences == ["\n\nHuman:"] + assert history.messages == [ + FormattedMessage(text=PRELUDE), + FormattedMessage( + text="\n\nHuman: system message1", source_message=messages[0] + ), + FormattedMessage( + text="\n\nHuman: human message1", + source_message=messages[1], + is_important=False, + ), + FormattedMessage( + text="\n\nAssistant: ai message1", + source_message=messages[2], + is_important=False, + ), + FormattedMessage( + text="\n\nHuman: human message2", source_message=messages[3] + ), + FormattedMessage(text="\n\nAssistant:"), + ] + + +def test_construction_with_single_user_message(): + messages: List[BaseMessage] = [HumanMessage(content=" human message ")] + history = PseudoChatHistory.create(messages) + + assert history.stop_sequences == [] + assert history.messages == [ + FormattedMessage(text=" human message ", source_message=messages[0]) + ] + + +def test_formatting(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2"), + FormattedMessage(text="text3"), + ] + history = PseudoChatHistory(messages=messages, stop_sequences=[]) + + prompt = history.format() + + assert prompt == "text1text2text3" + + +def test_no_trimming(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2"), + FormattedMessage(text="text3"), + ] + history = PseudoChatHistory(messages=messages, stop_sequences=[]) + + trimmed_history, discarded_messages_count = history.trim(lambda _: 1, 3) + + assert discarded_messages_count == 0 + assert trimmed_history == history + + +def test_trimming(): + messages = [ + FormattedMessage( + text="\n\nHuman: system message1", + source_message=SystemMessage(content="system message1"), + ), + FormattedMessage(text="to_remove1", is_important=False), + FormattedMessage( + text="\n\nHuman: system message2", + source_message=SystemMessage(content="system message2"), + ), + FormattedMessage(text="to_remove2", is_important=False), + FormattedMessage( + text="\n\nHuman: query1", + source_message=HumanMessage(content="query1"), + ), + FormattedMessage(text="\n\nAssistant:"), + ] + history = PseudoChatHistory( + messages=messages, + stop_sequences=[], + ) + + trimmed_history, discarded_messages_count = history.trim(lambda _: 1, 4) + + assert discarded_messages_count == 2 + assert trimmed_history.stop_sequences == ["\n\nHuman:"] + assert trimmed_history.messages == [ + FormattedMessage(text=PRELUDE), + FormattedMessage( + text="\n\nHuman: system message1", + source_message=SystemMessage(content="system message1"), + ), + FormattedMessage( + text="\n\nHuman: system 
message2", + source_message=SystemMessage(content="system message2"), + ), + FormattedMessage( + text="\n\nHuman: query1", + source_message=HumanMessage(content="query1"), + ), + FormattedMessage(text="\n\nAssistant:", source_message=None), + ] + + +def test_trimming_with_one_message_left(): + messages = [ + FormattedMessage( + text="text1", + source_message=AIMessage(content="reply1"), + is_important=False, + ), + FormattedMessage( + text="text2", + source_message=HumanMessage(content="query2"), + ), + ] + history = PseudoChatHistory( + messages=messages, + stop_sequences=[], + ) + + trimmed_history, discarded_messages_count = history.trim(lambda _: 1, 1) + + assert discarded_messages_count == 1 + assert trimmed_history.stop_sequences == [] + assert trimmed_history.messages == [ + FormattedMessage( + text="query2", + source_message=HumanMessage(content="query2"), + ) + ] + + +def test_trimming_with_one_message_accepted_after_second_check(): + messages = [ + FormattedMessage( + text="text1", + source_message=AIMessage(content="reply1"), + is_important=False, + ), + FormattedMessage( + text="text2", + source_message=HumanMessage(content="query1"), + ), + ] + history = PseudoChatHistory( + messages=messages, + stop_sequences=[], + ) + + trimmed_history, discarded_messages_count = history.trim( + lambda text: 1 if text == "query1" else 2, 1 + ) + + assert discarded_messages_count == 1 + assert trimmed_history.messages == [ + FormattedMessage( + text="query1", + source_message=HumanMessage(content="query1"), + ) + ] + + +def test_prompt_is_too_big(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2"), + FormattedMessage(text="text3"), + ] + history = PseudoChatHistory(messages=messages, stop_sequences=[]) + + with pytest.raises(ValidationError) as exc_info: + history.trim(lambda _: 1, 2) + + assert ( + str(exc_info.value) + == "Prompt token size (3) exceeds prompt token limit (2)." + ) + + +def test_prompt_with_history_is_too_big(): + messages = [ + FormattedMessage(text="text1"), + FormattedMessage(text="text2", is_important=False), + FormattedMessage(text="text3"), + ] + history = PseudoChatHistory(messages=messages, stop_sequences=[]) + + with pytest.raises(ValidationError) as exc_info: + history.trim(lambda _: 1, 1) + + assert ( + str(exc_info.value) + == "The token size of system messages and the last user message (2) exceeds prompt token limit (1)." 
+ ) diff --git a/tests/unit_tests/chat_emulation/test_zero_memory_chat_history.py b/tests/unit_tests/chat_emulation/test_zero_memory_chat_history.py new file mode 100644 index 00000000..0a9bc20d --- /dev/null +++ b/tests/unit_tests/chat_emulation/test_zero_memory_chat_history.py @@ -0,0 +1,33 @@ +from aidial_adapter_bedrock.llm.chat_emulation.history import FormattedMessage +from aidial_adapter_bedrock.llm.chat_emulation.zero_memory_chat import ( + ZeroMemoryChatHistory, +) +from aidial_adapter_bedrock.llm.message import ( + AIMessage, + HumanMessage, + SystemMessage, +) + + +def test_construction(): + messages = [ + SystemMessage(content="system message1"), + HumanMessage(content=" human message1 "), + AIMessage(content=" ai message1 "), + HumanMessage(content=" human message2 "), + ] + history = ZeroMemoryChatHistory.create(messages) + + assert history.discarded_messages == 3 + assert history.messages == [ + FormattedMessage(text=" human message2 ", source_message=messages[3]), + ] + + +def test_formatting(): + messages = [FormattedMessage(text="text")] + history = ZeroMemoryChatHistory(messages=messages, discarded_messages=0) + + prompt = history.format() + + assert prompt == "text" diff --git a/tests/unit_tests/test_stream.py b/tests/unit_tests/test_stream.py new file mode 100644 index 00000000..1316e5cf --- /dev/null +++ b/tests/unit_tests/test_stream.py @@ -0,0 +1,143 @@ +from typing import Generator, List, Tuple + +import pytest + +import tests.utils.string as string +from aidial_adapter_bedrock.utils.stream import ( + ensure_not_empty, + lstrip, + remove_prefix, + stop_at, +) + + +def list_to_gen(xs: List[str]) -> Generator[str, None, None]: + for x in xs: + yield x + + +def gen_to_string(gen: Generator[str, None, None]) -> str: + return "".join(x for x in gen) + + +lstrip_test_cases: List[Tuple[List[str]]] = [ + ([],), + (["a"],), + ([" a"],), + ([" a", " b"],), + (["", " a"],), + ([" ", "", " a"],), + ([" \n", " ", " a"],), + ([" a\n\tb\n\t"],), + ([" \n \t \n a\n\tb\n\t"],), + (["", " \n \t \n a\n\tb\n\t"],), + ([" \n", " ", " \t \n ", " a \n\tb\n\t"],), +] + + +@pytest.mark.parametrize( + "test", + lstrip_test_cases, + ids=lambda arg: f"{arg[0]}", +) +def test_lstrip(test): + (xs,) = test + gen = lstrip(list_to_gen(xs)) + actual = gen_to_string(gen) + expected = "".join(xs).lstrip() + assert actual == expected + + +remove_prefix_test_cases: List[Tuple[str, List[str]]] = [ + ("", []), + ("a", []), + ("a", ["b"]), + ("a", ["", "", "a", "b"]), + ("a", ["b", "a"]), + ("a", ["a", "a"]), + ("a", ["aa"]), + ("a", ["aaaaa"]), + ("abc", ["!abc!"]), + ("abc", ["abcabc"]), + ("a", ["a", "b"]), + ("prefix", ["prefix:xyz"]), + ("prefix", ["prefix:prefix:xyz"]), + ("a", ["Aa"]), + ("abc", ["a"]), + ("abc", ["a", "bc"]), + ("abc", ["a", "bcd"]), + ("abc", ["a", "bc", "d"]), +] + + +@pytest.mark.parametrize( + "test", + remove_prefix_test_cases, + ids=lambda arg: f"{arg[0]}-{arg[1]}", +) +def test_remove_prefix(test): + (prefix, xs) = test + gen = remove_prefix(list_to_gen(xs), prefix) + actual = gen_to_string(gen) + expected = string.remove_prefix(prefix, "".join(xs)) + assert actual == expected + + +stop_at_test_cases: List[Tuple[str | List[str], List[str]]] = [ + ("", []), + ("", ["a", "b"]), + ("a", ["b"]), + ("a", ["ba"]), + ("a", ["b", "a"]), + ("a", ["b", "a", "c"]), + ("a", ["bac"]), + ("a", ["baca"]), + ("abc", ["zabcy"]), + ("ab", ["d", "a", "b", "c"]), + ("hello", ["? hel", "lo world", "!"]), + ("hello world", ["? hel", "lo", " wor", "ld", "!"]), + ("hello worlD", ["? 
hel", "lo", " wor", "ld", "!"]), + ("z", ["ab", "cd"]), + ("z", ["", "", "a", " \t ", " z ", "tt"]), + (["hello", "world"], ["Hel", "lo", " ", "world", "!"]), + (["ab", "ba"], ["abba"]), + (["ba", "ab"], ["abba"]), + (["a", "b", "c"], ["abc"]), + (["c", "b", "a"], ["abc"]), + ([], ["abc", "d", "ef"]), +] + + +@pytest.mark.parametrize( + "test", + stop_at_test_cases, + ids=lambda arg: f"{arg[0]}-{arg[1]}", +) +def test_stop_at(test): + (stop, xs) = test + stop_sequences: List[str] = [stop] if isinstance(stop, str) else stop + gen = stop_at(list_to_gen(xs), stop_sequences) + actual = gen_to_string(gen) + expected = string.stop_at(stop_sequences, "".join(xs)) + assert actual == expected + + +ensure_not_empty_test_cases: List[Tuple[str | List[str], List[str]]] = [ + ("", []), + (" ", ["", "", "a"]), + (" ", ["", "", "\t", ""]), + (" ", ["abc", "de"]), +] + + +@pytest.mark.parametrize( + "test", + ensure_not_empty_test_cases, + ids=lambda arg: f"{arg[0]}-{arg[1]}", +) +def test_ensure_not_empty(test): + (default, xs) = test + gen = ensure_not_empty(list_to_gen(xs), default) + actual = gen_to_string(gen) + expected = string.ensure_not_empty(default, "".join(xs)) + assert actual == expected diff --git a/tests/utils/llm.py b/tests/utils/llm.py index a8332622..ec763c4b 100644 --- a/tests/utils/llm.py +++ b/tests/utils/llm.py @@ -1,5 +1,5 @@ import re -from typing import List +from typing import List, Optional from langchain.callbacks.base import Callbacks from langchain.chat_models import AzureChatOpenAI @@ -28,9 +28,12 @@ def sanitize_test_name(name: str) -> str: async def run_model( - model: BaseChatModel, messages: List[BaseMessage], streaming: bool + model: BaseChatModel, + messages: List[BaseMessage], + streaming: bool, + stop: Optional[List[str]], ) -> str: - llm_result = await model.agenerate([messages]) + llm_result = await model.agenerate([messages], stop=stop) actual_usage = ( llm_result.llm_output.get("token_usage", None) @@ -45,7 +48,10 @@ async def run_model( def create_model( - base_url: str, model_id: str, streaming: bool + base_url: str, + model_id: str, + streaming: bool, + max_tokens: Optional[int], ) -> BaseChatModel: callbacks: Callbacks = [CallbackWithNewLines()] return AzureChatOpenAI( @@ -54,9 +60,13 @@ def create_model( openai_api_base=base_url, openai_api_version=DEFAULT_API_VERSION, openai_api_key="dummy_openai_api_key", - model_kwargs={"deployment_id": model_id, "api_key": "dummy_api_key"}, + model_kwargs={ + "deployment_id": model_id, + "api_key": "dummy_api_key", + }, verbose=True, streaming=streaming, + max_tokens=max_tokens, temperature=0.0, request_timeout=10, client=None, diff --git a/tests/utils/string.py b/tests/utils/string.py new file mode 100644 index 00000000..25bfc81b --- /dev/null +++ b/tests/utils/string.py @@ -0,0 +1,19 @@ +from typing import List + + +def remove_prefix(prefix: str, string: str) -> str: + if string.startswith(prefix): + return string[len(prefix) :] + return string + + +def stop_at(stop_sequences: List[str], string: str) -> str: + min_index = len(string) + for stop_sequence in stop_sequences: + if stop_sequence in string: + min_index = min(min_index, string.index(stop_sequence)) + return string[:min_index] + + +def ensure_not_empty(default: str, string: str) -> str: + return default if string == "" else string From 537ecb4fe2ccdfc9aaae04782c1eb73f8fa58b4c Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Fri, 10 Nov 2023 17:39:56 +0000 Subject: [PATCH 2/2] fix: fixed AI21 temperature setting --- aidial_adapter_bedrock/llm/model/ai21.py | 
10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/aidial_adapter_bedrock/llm/model/ai21.py b/aidial_adapter_bedrock/llm/model/ai21.py
index ed114fa9..1b536afe 100644
--- a/aidial_adapter_bedrock/llm/model/ai21.py
+++ b/aidial_adapter_bedrock/llm/model/ai21.py
@@ -76,12 +76,10 @@ def prepare_model_kwargs(model_params: ModelParameters) -> Dict[str, Any]:
         model_kwargs["maxTokens"] = DEFAULT_MAX_TOKENS_AI21
 
     if model_params.temperature is not None:
-        model_kwargs["temperature"] = model_params.temperature
-    else:
-        # The default AI21 temperature is 0.7.
-        # The default OpenAI temperature is 1.0.
-        # Choosing the OpenAI default since we pretend AI21 to be OpenAI.
-        model_kwargs["temperature"] = 1.0
+        # AI21 temperature ranges from 0.0 to 1.0
+        # OpenAI temperature ranges from 0.0 to 2.0
+        # Thus scaling down by 2x to match the AI21 range
+        model_kwargs["temperature"] = model_params.temperature / 2.0
 
     if model_params.top_p is not None:
         model_kwargs["topP"] = model_params.top_p
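
Note: a minimal sketch (not part of the patch) of the temperature rescaling introduced in the hunk above, assuming a hypothetical `rescale_temperature` helper that mirrors the patched logic; in the adapter itself this lives inside `prepare_model_kwargs`, which builds the full `model_kwargs` dict.

```python
from typing import Any, Dict, Optional


def rescale_temperature(openai_temperature: Optional[float]) -> Dict[str, Any]:
    # Hypothetical helper, not part of the adapter codebase: maps an
    # OpenAI-style temperature (0.0..2.0) onto the AI21 range (0.0..1.0).
    kwargs: Dict[str, Any] = {}
    if openai_temperature is not None:
        kwargs["temperature"] = openai_temperature / 2.0
    return kwargs


# The OpenAI default of 1.0 lands in the middle of the AI21 range,
# while an unset temperature falls through to the AI21 default.
assert rescale_temperature(1.0) == {"temperature": 0.5}
assert rescale_temperature(2.0) == {"temperature": 1.0}
assert rescale_temperature(None) == {}
```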