From df27c53fe1453b6045c5791ef5db8757cfd91b75 Mon Sep 17 00:00:00 2001 From: Pierce Kelaita Date: Wed, 24 Jul 2024 22:00:10 -0700 Subject: [PATCH] [octoai] add chat models --- README.md | 46 ++++++----- l2m2/_internal/http.py | 2 +- l2m2/client/base_llm_client.py | 20 ++++- l2m2/exceptions.py | 6 ++ l2m2/model_info.py | 96 ++++++++++++++++++++++- tests/l2m2/_internal/test_http.py | 49 ++++++++++-- tests/l2m2/client/test_base_llm_client.py | 16 ++-- 7 files changed, 194 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 944a578..a395cd4 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # L2M2: A Simple Python LLM Manager 💬👍 -[![Tests](https://github.com/pkelaita/l2m2/actions/workflows/tests.yml/badge.svg?timestamp=1721868974)](https://github.com/pkelaita/l2m2/actions/workflows/tests.yml) [![codecov](https://codecov.io/github/pkelaita/l2m2/graph/badge.svg?token=UWIB0L9PR8)](https://codecov.io/github/pkelaita/l2m2) [![PyPI version](https://badge.fury.io/py/l2m2.svg?timestamp=1721868974)](https://badge.fury.io/py/l2m2) +[![Tests](https://github.com/pkelaita/l2m2/actions/workflows/tests.yml/badge.svg?timestamp=1721884488)](https://github.com/pkelaita/l2m2/actions/workflows/tests.yml) [![codecov](https://codecov.io/github/pkelaita/l2m2/graph/badge.svg?token=UWIB0L9PR8)](https://codecov.io/github/pkelaita/l2m2) [![PyPI version](https://badge.fury.io/py/l2m2.svg?timestamp=1721884488)](https://badge.fury.io/py/l2m2) **L2M2** ("LLM Manager" → "LLMM" → "L2M2") is a tiny and very simple LLM manager for Python that exposes lots of models through a unified API. This is useful for evaluation, demos, production applications etc. that need to easily be model-agnostic. ### Features -- 17 supported models (see below) – regularly updated and with more on the way +- 21 supported models (see below) – regularly updated and with more on the way - Session chat memory – even across multiple models - JSON mode - Prompt loading tools @@ -23,25 +23,29 @@ L2M2 currently supports the following models: -| Model Name | Provider(s) | Model Version(s) | -| ------------------- | -------------------------------------------------------------------- | ------------------------------------------------------------------- | -| `gpt-4o` | [OpenAI](https://openai.com/product) | `gpt-4o-2024-05-13` | -| `gpt-4o-mini` | [OpenAI](https://openai.com/product) | `gpt-4o-mini-2024-07-18` | -| `gpt-4-turbo` | [OpenAI](https://openai.com/product) | `gpt-4-turbo-2024-04-09` | -| `gpt-3.5-turbo` | [OpenAI](https://openai.com/product) | `gpt-3.5-turbo-0125` | -| `gemini-1.5-pro` | [Google](https://ai.google.dev/) | `gemini-1.5-pro` | -| `gemini-1.0-pro` | [Google](https://ai.google.dev/) | `gemini-1.0-pro` | -| `claude-3.5-sonnet` | [Anthropic](https://www.anthropic.com/api) | `claude-3-5-sonnet-20240620` | -| `claude-3-opus` | [Anthropic](https://www.anthropic.com/api) | `claude-3-opus-20240229` | -| `claude-3-sonnet` | [Anthropic](https://www.anthropic.com/api) | `claude-3-sonnet-20240229` | -| `claude-3-haiku` | [Anthropic](https://www.anthropic.com/api) | `claude-3-haiku-20240307` | -| `command-r` | [Cohere](https://docs.cohere.com/) | `command-r` | -| `command-r-plus` | [Cohere](https://docs.cohere.com/) | `command-r-plus` | -| `mixtral-8x7b` | [Groq](https://wow.groq.com/) | `mixtral-8x7b-32768` | -| `gemma-7b` | [Groq](https://wow.groq.com/) | `gemma-7b-it` | -| `llama3-8b` | [Groq](https://wow.groq.com/), [Replicate](https://replicate.com/) | `llama3-8b-8192`, `meta/meta-llama-3-8b-instruct` | -| `llama3-70b` 
| [Groq](https://wow.groq.com/), [Replicate](https://replicate.com/) | `llama3-70b-8192`, `meta/meta-llama-3-70b-instruct` | -| `llama3.1-405b` | [Replicate](https://replicate.com/), [OctoAI](https://octoai.cloud/) | `meta/meta-llama-3.1-405b-instruct`, `meta-llama-3.1-405b-instruct` | +| Model Name | Provider(s) | Model Version(s) | +| ------------------- | --------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------- | +| `gpt-4o` | [OpenAI](https://openai.com/product) | `gpt-4o-2024-05-13` | +| `gpt-4o-mini` | [OpenAI](https://openai.com/product) | `gpt-4o-mini-2024-07-18` | +| `gpt-4-turbo` | [OpenAI](https://openai.com/product) | `gpt-4-turbo-2024-04-09` | +| `gpt-3.5-turbo` | [OpenAI](https://openai.com/product) | `gpt-3.5-turbo-0125` | +| `gemini-1.5-pro` | [Google](https://ai.google.dev/) | `gemini-1.5-pro` | +| `gemini-1.0-pro` | [Google](https://ai.google.dev/) | `gemini-1.0-pro` | +| `claude-3.5-sonnet` | [Anthropic](https://www.anthropic.com/api) | `claude-3-5-sonnet-20240620` | +| `claude-3-opus` | [Anthropic](https://www.anthropic.com/api) | `claude-3-opus-20240229` | +| `claude-3-sonnet` | [Anthropic](https://www.anthropic.com/api) | `claude-3-sonnet-20240229` | +| `claude-3-haiku` | [Anthropic](https://www.anthropic.com/api) | `claude-3-haiku-20240307` | +| `command-r` | [Cohere](https://docs.cohere.com/) | `command-r` | +| `command-r-plus` | [Cohere](https://docs.cohere.com/) | `command-r-plus` | +| `mistral-7b` | [OctoAI](https://octoai.cloud/) | `mistral-7b-instruct` | +| `mixtral-8x7b` | [Groq](https://wow.groq.com/), [OctoAI](https://octoai.cloud/) | `mixtral-8x7b-32768`, `mixtral-8x7b-instruct` | +| `mixtral-8x22b` | [OctoAI](https://octoai.cloud/) | `mixtral-8x22b-instruct` | +| `gemma-7b` | [Groq](https://wow.groq.com/) | `gemma-7b-it` | +| `llama3-8b` | [Groq](https://wow.groq.com/), [Replicate](https://replicate.com/) | `llama3-8b-8192`, `meta/meta-llama-3-8b-instruct` | +| `llama3-70b` | [Groq](https://wow.groq.com/), [Replicate](https://replicate.com/), [OctoAI](https://octoai.cloud/) | `llama3-70b-8192`, `meta/meta-llama-3-70b-instruct`, `meta-llama-3-70b-instruct` | +| `llama3.1-8b` | [OctoAI](https://octoai.cloud/) | `meta-llama-3.1-8b-instruct` | +| `llama3.1-70b` | [OctoAI](https://octoai.cloud/) | `meta-llama-3.1-70b-instruct` | +| `llama3.1-405b` | [Replicate](https://replicate.com/), [OctoAI](https://octoai.cloud/) | `meta/meta-llama-3.1-405b-instruct`, `meta-llama-3.1-405b-instruct` | diff --git a/l2m2/_internal/http.py b/l2m2/_internal/http.py index 5d99613..1f5694e 100644 --- a/l2m2/_internal/http.py +++ b/l2m2/_internal/http.py @@ -39,10 +39,10 @@ async def _handle_replicate_201( async def llm_post( client: httpx.AsyncClient, provider: str, + model_id: str, api_key: str, data: Dict[str, Any], timeout: Optional[int], - model_id: Optional[str] = None, ) -> Any: endpoint = PROVIDER_INFO[provider]["endpoint"] if API_KEY in endpoint: diff --git a/l2m2/client/base_llm_client.py b/l2m2/client/base_llm_client.py index ea3bf87..07a7afe 100644 --- a/l2m2/client/base_llm_client.py +++ b/l2m2/client/base_llm_client.py @@ -22,6 +22,7 @@ get_extra_message, run_json_strats_out, ) +from l2m2.exceptions import LLMOperationError from l2m2._internal.http import llm_post @@ -501,6 +502,7 @@ async def _call_openai( result = await llm_post( client=self.httpx_client, provider="openai", + model_id=model_id, api_key=self.api_keys["openai"], data={"model": 
model_id, "messages": messages, **params}, timeout=timeout, @@ -532,6 +534,7 @@ async def _call_anthropic( result = await llm_post( client=self.httpx_client, provider="anthropic", + model_id=model_id, api_key=self.api_keys["anthropic"], data={"model": model_id, "messages": messages, **params}, timeout=timeout, @@ -564,6 +567,7 @@ async def _call_cohere( result = await llm_post( client=self.httpx_client, provider="cohere", + model_id=model_id, api_key=self.api_keys["cohere"], data={"model": model_id, "message": prompt, **params}, timeout=timeout, @@ -595,6 +599,7 @@ async def _call_groq( result = await llm_post( client=self.httpx_client, provider="groq", + model_id=model_id, api_key=self.api_keys["groq"], data={"model": model_id, "messages": messages, **params}, timeout=timeout, @@ -633,10 +638,10 @@ async def _call_google( result = await llm_post( client=self.httpx_client, provider="google", + model_id=model_id, api_key=self.api_keys["google"], data=data, timeout=timeout, - model_id=model_id, ) result = result["candidates"][0] @@ -657,12 +662,12 @@ async def _call_replicate( json_mode_strategy: JsonModeStrategy, ) -> str: if isinstance(self.memory, ChatMemory): - raise ValueError( + raise LLMOperationError( "Chat memory is not supported with Replicate." + " Try using Groq, or using ExternalMemory instead." ) if json_mode_strategy.strategy_name == StrategyName.PREPEND: - raise ValueError( + raise LLMOperationError( "JsonModeStrategy.prepend() is not supported with Replicate." + " Try using Groq, or using JsonModeStrategy.strip() instead." ) @@ -673,10 +678,10 @@ async def _call_replicate( result = await llm_post( client=self.httpx_client, provider="replicate", + model_id=model_id, api_key=self.api_keys["replicate"], data={"input": {"prompt": prompt, **params}}, timeout=timeout, - model_id=model_id, ) return "".join(result["output"]) @@ -690,6 +695,12 @@ async def _call_octoai( json_mode: bool, json_mode_strategy: JsonModeStrategy, ) -> str: + if isinstance(self.memory, ChatMemory) and model_id == "mixtral-8x22b-instruct": + raise LLMOperationError( + "Chat memory is not supported with mixtral-8x22b via OctoAI. Try using" + + " ExternalMemory instead, or ChatMemory with a different model/provider." 
+ ) + messages = [] if system_prompt is not None: messages.append({"role": "system", "content": system_prompt}) @@ -705,6 +716,7 @@ async def _call_octoai( result = await llm_post( client=self.httpx_client, provider="octoai", + model_id=model_id, api_key=self.api_keys["octoai"], data={"model": model_id, "messages": messages, **params}, timeout=timeout, diff --git a/l2m2/exceptions.py b/l2m2/exceptions.py index 504987e..20bcf43 100644 --- a/l2m2/exceptions.py +++ b/l2m2/exceptions.py @@ -8,3 +8,9 @@ class LLMRateLimitError(Exception): """Raised when a request to an LLM provider API is rate limited.""" pass + + +class LLMOperationError(Exception): + """Raised when a model does not support a particular feature or mode.""" + + pass diff --git a/l2m2/model_info.py b/l2m2/model_info.py index bcdbe6f..edb4c70 100644 --- a/l2m2/model_info.py +++ b/l2m2/model_info.py @@ -310,6 +310,22 @@ class ModelEntry(TypedDict): "extras": {}, }, }, + "mistral-7b": { + "octoai": { + "model_id": "mistral-7b-instruct", + "params": { + "temperature": { + "default": PROVIDER_DEFAULT, + "max": 2.0, + }, + "max_tokens": { + "default": PROVIDER_DEFAULT, + "max": INF, + }, + }, + "extras": {}, + }, + }, "mixtral-8x7b": { "groq": { "model_id": "mixtral-8x7b-32768", @@ -325,6 +341,36 @@ class ModelEntry(TypedDict): }, "extras": {}, }, + "octoai": { + "model_id": "mixtral-8x7b-instruct", + "params": { + "temperature": { + "default": PROVIDER_DEFAULT, + "max": 2.0, + }, + "max_tokens": { + "default": PROVIDER_DEFAULT, + "max": INF, + }, + }, + "extras": {}, + }, + }, + "mixtral-8x22b": { + "octoai": { + "model_id": "mixtral-8x22b-instruct", + "params": { + "temperature": { + "default": PROVIDER_DEFAULT, + "max": 2.0, + }, + "max_tokens": { + "default": PROVIDER_DEFAULT, + "max": INF, + }, + }, + "extras": {}, + }, }, "gemma-7b": { "groq": { @@ -348,7 +394,7 @@ class ModelEntry(TypedDict): "params": { "temperature": { "default": PROVIDER_DEFAULT, - "max": 2, + "max": 2.0, }, "max_tokens": { "default": PROVIDER_DEFAULT, @@ -379,7 +425,7 @@ class ModelEntry(TypedDict): "params": { "temperature": { "default": PROVIDER_DEFAULT, - "max": 2, + "max": 2.0, }, "max_tokens": { "default": PROVIDER_DEFAULT, @@ -403,6 +449,52 @@ class ModelEntry(TypedDict): }, "extras": {}, }, + "octoai": { + "model_id": "meta-llama-3-70b-instruct", + "params": { + "temperature": { + "default": PROVIDER_DEFAULT, + "max": 2.0, + }, + "max_tokens": { + "default": PROVIDER_DEFAULT, + "max": INF, + }, + }, + "extras": {}, + }, + }, + "llama3.1-8b": { + "octoai": { + "model_id": "meta-llama-3.1-8b-instruct", + "params": { + "temperature": { + "default": PROVIDER_DEFAULT, + "max": 2.0, + }, + "max_tokens": { + "default": PROVIDER_DEFAULT, + "max": INF, + }, + }, + "extras": {}, + }, + }, + "llama3.1-70b": { + "octoai": { + "model_id": "meta-llama-3.1-70b-instruct", + "params": { + "temperature": { + "default": PROVIDER_DEFAULT, + "max": 2.0, + }, + "max_tokens": { + "default": PROVIDER_DEFAULT, + "max": INF, + }, + }, + "extras": {}, + }, }, "llama3.1-405b": { "replicate": { diff --git a/tests/l2m2/_internal/test_http.py b/tests/l2m2/_internal/test_http.py index 972cc85..dd1c89d 100644 --- a/tests/l2m2/_internal/test_http.py +++ b/tests/l2m2/_internal/test_http.py @@ -3,7 +3,7 @@ import respx from unittest.mock import patch -from l2m2.exceptions import LLMTimeoutError +from l2m2.exceptions import LLMRateLimitError, LLMTimeoutError from l2m2.model_info import API_KEY, MODEL_ID from l2m2._internal.http import ( _get_headers, @@ -118,7 +118,14 @@ async def 
_test_generic_llm_post(provider): return_value=httpx.Response(200, json=expected_response) ) async with httpx.AsyncClient() as client: - result = await llm_post(client, provider, api_key, data, 10, model_id) + result = await llm_post( + client=client, + provider=provider, + model_id=model_id, + api_key=api_key, + data=data, + timeout=10, + ) assert result == expected_response await _test_generic_llm_post("test_provider") @@ -153,7 +160,14 @@ async def test_llm_post_replicate(): ) async with httpx.AsyncClient() as client: - result = await llm_post(client, provider, api_key, data, 10) + result = await llm_post( + client=client, + provider=provider, + model_id="fake_model_id", + api_key=api_key, + data=data, + timeout=10, + ) assert result == mock_success_response @@ -236,7 +250,14 @@ async def test_llm_post_failure(): respx.post(endpoint).mock(return_value=httpx.Response(400, text="Bad Request")) async with httpx.AsyncClient() as client: with pytest.raises(Exception): - await llm_post(client, provider, api_key, data, model_id) + await llm_post( + client=client, + provider=provider, + model_id=model_id, + api_key=api_key, + data=data, + timeout=10, + ) @pytest.mark.asyncio @@ -258,7 +279,14 @@ async def test_llm_post_timeout(): respx.post(endpoint).mock(side_effect=httpx.ReadTimeout) async with httpx.AsyncClient() as client: with pytest.raises(LLMTimeoutError): - await llm_post(client, provider, api_key, data, timeout, model_id) + await llm_post( + client=client, + provider=provider, + model_id=model_id, + api_key=api_key, + data=data, + timeout=timeout, + ) @pytest.mark.asyncio @@ -280,5 +308,12 @@ async def test_llm_post_rate_limit_error(): return_value=httpx.Response(429, text="Rate Limit Exceeded") ) async with httpx.AsyncClient() as client: - with pytest.raises(Exception): - await llm_post(client, provider, api_key, data, model_id) + with pytest.raises(LLMRateLimitError): + await llm_post( + client=client, + provider=provider, + model_id=model_id, + api_key=api_key, + data=data, + timeout=10, + ) diff --git a/tests/l2m2/client/test_base_llm_client.py b/tests/l2m2/client/test_base_llm_client.py index 63f5810..a1c05e3 100644 --- a/tests/l2m2/client/test_base_llm_client.py +++ b/tests/l2m2/client/test_base_llm_client.py @@ -10,11 +10,18 @@ ) from l2m2.client.base_llm_client import BaseLLMClient from l2m2.tools import JsonModeStrategy +from l2m2.exceptions import LLMOperationError LLM_POST_PATH = "l2m2.client.base_llm_client.llm_post" GET_EXTRA_MESSAGE_PATH = "l2m2.client.base_llm_client.get_extra_message" CALL_BASE_PATH = "l2m2.client.base_llm_client.BaseLLMClient._call_" +# All of the model/provider pairs which don't support ChatMemory +CHAT_MEMORY_UNSUPPORTED_MODELS = { + "octoai": "mixtral-8x22b", + "replicate": "llama3-8b", # Applies to all models via Replicate +} + # -- Fixtures -- # @@ -480,12 +487,9 @@ def test_chat_memory_errors(llm_client): @pytest.mark.asyncio async def test_chat_memory_unsupported_provider(llm_client_mem_chat): - unsupported_providers = { - "replicate": "llama3-8b", - } - for provider, model in unsupported_providers.items(): + for provider, model in CHAT_MEMORY_UNSUPPORTED_MODELS.items(): llm_client_mem_chat.add_provider(provider, "fake-api-key") - with pytest.raises(ValueError): + with pytest.raises(LLMOperationError): await llm_client_mem_chat.call(prompt="Hello", model=model) @@ -680,7 +684,7 @@ async def test_json_mode_strategy_prepend_groq(mock_call_groq, llm_client): @pytest.mark.asyncio async def 
test_json_mode_strategy_prepend_replicate_throws_error(llm_client): llm_client.add_provider("replicate", "fake-api-key") - with pytest.raises(ValueError): + with pytest.raises(LLMOperationError): await llm_client.call( prompt="Hello", model="llama3-8b",
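
---

For reviewers, here is a rough end-to-end sketch (not part of the patch) of how the OctoAI models and the `LLMOperationError` introduced above are meant to be used together. The client class name, its no-argument constructor, and the awaitable `call()` are assumptions inferred from the tests in this patch and may need adjusting against the real public API; the chat-memory wiring that actually triggers the error is omitted.

```python
# Hypothetical usage sketch -- not part of this patch.
# Assumptions: a public client exposing add_provider() and an awaitable call(),
# as exercised by the tests above; the import path and no-arg constructor below
# are guesses, since neither is shown in this diff.
import asyncio

from l2m2.exceptions import LLMOperationError  # added in this patch


async def main() -> None:
    from l2m2.client import LLMClient  # assumed public client class

    client = LLMClient()  # assumed no-argument constructor
    client.add_provider("octoai", "<octoai-api-key>")

    # One of the chat models added by this patch, served via OctoAI.
    answer = await client.call(
        prompt="What is the capital of France?",
        model="llama3.1-70b",
    )
    print(answer)

    # When chat memory is enabled on the client (memory wiring is not shown in
    # this diff and is omitted here), mixtral-8x22b via OctoAI is rejected with
    # the new LLMOperationError rather than a generic ValueError.
    try:
        await client.call(prompt="Hello again", model="mixtral-8x22b")
    except LLMOperationError as e:
        print(f"Unsupported operation: {e}")


if __name__ == "__main__":
    asyncio.run(main())
```

Note that the same exception now also covers the Replicate limitations (chat memory and `JsonModeStrategy.prepend()`) that previously raised `ValueError`, so callers that caught `ValueError` for those cases should catch `LLMOperationError` instead.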