-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #22 from hmasdev/add-another-chat-model
Add Groq and Gemini to the valid LLM services
- Loading branch information
Showing
13 changed files
with
333 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import os | ||
from typing import Generator | ||
from dotenv import load_dotenv | ||
import pytest | ||
from werewolf.const import SERVICE_APIKEY_ENVVAR_MAP | ||
|
||
|
||
@pytest.fixture(scope='function')
def api_keys(dummy: str = 'dummy') -> Generator[None, None, None]:
    """Ensure every service API-key env var is set for the duration of a test.

    Real keys from the environment or the ``.env`` file are used when
    available; any variable that is still unset is filled with ``dummy``.
    On teardown, variables that hold the dummy placeholder are removed so
    the environment is restored.

    Note:
        ``dummy`` has a default value so pytest does not treat it as a
        fixture request.
    """
    # Load .env BEFORE filling in placeholders: python-dotenv's default is
    # override=False, so if the variables were already set to the dummy
    # value the real keys in .env could never take effect.
    load_dotenv()
    for envvar in SERVICE_APIKEY_ENVVAR_MAP.values():
        os.environ[envvar] = os.getenv(envvar) or dummy  # type: ignore
    yield
    # Remove only the variables we faked; real keys are left untouched.
    for envvar in SERVICE_APIKEY_ENVVAR_MAP.values():
        if os.getenv(envvar) == dummy:
            del os.environ[envvar]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import os | ||
from dotenv import load_dotenv | ||
import pytest | ||
from pytest_mock import MockerFixture | ||
from langchain_core.language_models import BaseChatModel | ||
from langchain_google_genai import ChatGoogleGenerativeAI | ||
from langchain_groq import ChatGroq | ||
from langchain_openai import ChatOpenAI | ||
from werewolf.chat_models import create_chat_model, _service2cls | ||
from werewolf.const import MODEL_SERVICE_MAP | ||
|
||
load_dotenv()

# Map every known model name to the chat-model class that serves it,
# going through the model -> service -> class lookup tables.
name2cls: dict[str, type[BaseChatModel]] = {
    model: _service2cls[service]
    for model, service in MODEL_SERVICE_MAP.items()
}
|
||
|
||
@pytest.mark.usefixtures('api_keys')
@pytest.mark.parametrize(
    'llm, expected',
    # Simplified: the original wrapped this in a no-op enumerate.
    list(name2cls.items()),
)
def test_create_chat_model_wo_seed(
    llm: str,
    expected: type[BaseChatModel],
    mocker: MockerFixture,
) -> None:
    """create_chat_model(llm) should instantiate the class mapped to the name."""
    # Patch the class inside werewolf.chat_models so no real client
    # (and no real API key) is needed.
    mocker.patch(f'werewolf.chat_models.{expected.__name__}', autospec=True)  # noqa
    # autospec mocks do not satisfy isinstance, so compare class names.
    assert create_chat_model(llm).__class__.__name__ == expected.__name__
|
||
|
||
@pytest.mark.usefixtures('api_keys')
@pytest.mark.parametrize(
    'llm, seed, expected',
    # enumerate supplies a distinct seed value per model.
    [(model, idx, cls) for idx, (model, cls) in enumerate(name2cls.items())],
)
def test_create_chat_model_w_seed(
    llm: str,
    seed: int,
    expected: type[BaseChatModel],
    mocker: MockerFixture,
) -> None:
    """create_chat_model(llm, seed) should instantiate the mapped class."""
    # Patch the class so no network client is constructed.
    cls_mock = mocker.patch(f'werewolf.chat_models.{expected.__name__}', autospec=True)  # noqa
    # assert isinstance(create_chat_model(llm, seed), expected)
    assert create_chat_model(llm, seed).__class__.__name__ == expected.__name__  # noqa

    # TODO: fix the following assertion
    # cls_mock.assert_called_once_with(model=llm, seed=seed)
|
||
|
||
def test_create_chat_model_w_invalid_llm() -> None:
    """An unknown model name must raise ValueError."""
    # Callable form of pytest.raises — equivalent to the with-block form.
    pytest.raises(ValueError, create_chat_model, 'invalid')
|
||
|
||
@pytest.mark.integration
@pytest.mark.skipif(
    os.getenv("OPENAI_API_KEY") is None,
    # Fixed: the reason previously said "OPENAI_API", which is not the
    # variable actually checked.
    reason="OPENAI_API_KEY is not set.",
)
@pytest.mark.parametrize(
    'llm',
    [
        'gpt-4o-mini',
        'gpt-4',
        'gpt-4-turbo',
        'gpt-4o',
        'gpt-3.5-turbo',
    ]
)
def test_create_chat_model_for_ChatOpenAI_integration(llm: str) -> None:
    """OpenAI model names should resolve to a real ChatOpenAI instance."""
    assert isinstance(create_chat_model(llm), ChatOpenAI)
|
||
|
||
@pytest.mark.integration
@pytest.mark.skipif(
    os.getenv("GROQ_API_KEY") is None,
    reason="GROQ_API_KEY is not set.",
)
@pytest.mark.parametrize(
    'llm',
    (
        'llama3-groq-70b-8192-tool-use-preview',
        'llama3-groq-8b-8192-tool-use-preview',
        'llama-3.1-70b-versatile',
        'llama-3.1-8b-instant',
        'llama-guard-3-8b',
        'llava-v1.5-7b-4096-preview',
        'llama3-70b-8192',
        'llama3-8b-8192',
        'mixtral-8x7b-32768',
        'gemma2-9b-it',
        'gemma2-7b-it',
    ),
)
def test_create_chat_model_for_ChatGroq_integration(llm: str) -> None:
    """Groq model names should resolve to a real ChatGroq instance."""
    assert isinstance(create_chat_model(llm), ChatGroq)
|
||
|
||
@pytest.mark.integration
@pytest.mark.skipif(
    os.getenv("GOOGLE_API_KEY") is None,
    reason="GOOGLE_API_KEY is not set.",
)
@pytest.mark.parametrize(
    'llm',
    (
        'gemini-1.5-flash',
        'gemini-pro-vision',
        'gemini-pro',
    ),
)
def test_create_chat_model_for_ChatGoogleGenerativeAI_integration(llm: str) -> None:  # noqa
    """Gemini model names should resolve to a real ChatGoogleGenerativeAI."""
    assert isinstance(create_chat_model(llm), ChatGoogleGenerativeAI)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from functools import lru_cache | ||
from logging import getLogger, Logger | ||
from langchain_core.language_models import BaseChatModel | ||
# from langchain_community.chat_models import ChatPerplexity | ||
from langchain_google_genai import ChatGoogleGenerativeAI | ||
from langchain_groq import ChatGroq | ||
from langchain_openai import ChatOpenAI | ||
|
||
from .const import DEFAULT_MODEL, EChatService, MODEL_SERVICE_MAP | ||
|
||
|
||
# Resolve each supported chat service to the langchain chat-model class
# that implements it. Extend this mapping when adding a new service.
_service2cls: dict[EChatService, type[BaseChatModel]] = {
    EChatService.OpenAI: ChatOpenAI,
    EChatService.Google: ChatGoogleGenerativeAI,
    EChatService.Groq: ChatGroq,
}
|
||
|
||
@lru_cache(maxsize=None)
def create_chat_model(
    llm: BaseChatModel | str = DEFAULT_MODEL,
    seed: int | None = None,
    logger: Logger = getLogger(__name__),
    **kwargs,
) -> BaseChatModel:
    """Create a ChatModel instance.

    Args:
        llm (BaseChatModel | str, optional): ChatModel instance or model name.
            Defaults to DEFAULT_MODEL.
        seed (int | None, optional): Random seed. Defaults to None.
        logger (Logger, optional): Logger. Defaults to getLogger(__name__).
        **kwargs: Extra keyword arguments forwarded to the chat-model class.

    Raises:
        ValueError: Unknown model name.

    Returns:
        BaseChatModel: ChatModel instance.

    Note:
        seed is used only when llm is a str.
        Due to lru_cache, the same parameters return the same instance.
    """  # noqa
    # A falsy llm (None or '') falls back to the default model name.
    llm = llm or DEFAULT_MODEL
    if not isinstance(llm, str):
        # Already a chat-model instance: return it unchanged.
        return llm
    try:
        cls = _service2cls[MODEL_SERVICE_MAP[llm]]
    except KeyError as e:
        # Chain the lookup failure so the original cause stays visible.
        raise ValueError(f'Unknown model name: {llm}') from e
    if seed is not None:
        try:
            return cls(model=llm, seed=seed, **kwargs)  # type: ignore # noqa
        except TypeError:
            # Some services reject a seed kwarg; fall through to retry
            # without it. (Previously this warning could also fire, and the
            # call be pointlessly retried, when no seed was given at all.)
            logger.warning(f'{llm} does not support seed.')
    return cls(model=llm, **kwargs)  # type: ignore # noqa
Oops, something went wrong.