Skip to content

Commit

Permalink
Merge pull request #22 from hmasdev/add-another-chat-model
Browse files Browse the repository at this point in the history
Add Groq and Gemini to the valid LLM services
  • Loading branch information
hmasdev authored Sep 7, 2024
2 parents fb5cbc6 + 1b5f1f8 commit e17e2ad
Show file tree
Hide file tree
Showing 13 changed files with 333 additions and 42 deletions.
16 changes: 12 additions & 4 deletions README.md.j2
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@

## Requirements

- OpenAI API Key
- [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys)
- Get your API key
- OpenAI API Key
- [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys)
- Groq API Key
- [https://console.groq.com/keys](https://console.groq.com/keys)
- Gemini API Key
- [https://aistudio.google.com/app/apikey](https://aistudio.google.com/app/apikey)

- (optional) `docker compose`
- (optional) python >= 3.10
Expand All @@ -23,10 +28,13 @@ Note that either of `docker compose` or `python` is required.
### Preparation

1. Create `.env` file
2. Set `OPENAI_API_KEY`:
2. Set `OPENAI_API_KEY`, `GROQ_API_KEY` or `GOOGLE_API_KEY` in the `.env` file as follows:

```text
OPENAI_API_KEY=HERE_IS_YOUR_OPENAI_API_KEY
OPENAI_API_KEY=HERE_IS_YOUR_API_KEY
GROQ_API_KEY=HERE_IS_YOUR_API_KEY
GOOGLE_API_KEY=HERE_IS_YOUR_API_KEY

```

3. If you don't use `docker` but `python` in your machine, create a virtual environment and install libraries manually:
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ requires-python = ">=3.10"
dependencies = [
"click>=8.1.7",
"langchain>=0.2.0",
"langchain-groq>=0.1.9",
"langchain-google-genai>=1.0.10",
"langchain-openai",
"pyautogen==0.2.16",
"pyautogen[gemini,groq]>=0.2.32",
"python-dotenv>=1.0.1",
]
authors = [{ name = "hmasdev" }]
Expand Down
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
from typing import Generator
from dotenv import load_dotenv
import pytest
from werewolf.const import SERVICE_APIKEY_ENVVAR_MAP


@pytest.fixture(scope='function')
def api_keys(dummy: str = 'dummy') -> Generator[None, None, None]:
for envvar in SERVICE_APIKEY_ENVVAR_MAP.values():
os.environ[envvar] = os.getenv(envvar) or dummy # type: ignore
load_dotenv()
yield
for envvar in SERVICE_APIKEY_ENVVAR_MAP.values():
if os.getenv(envvar) == dummy:
del os.environ[envvar]
25 changes: 14 additions & 11 deletions tests/game_master/test_default_game_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from dataclasses import asdict
import os
import autogen
from dotenv import load_dotenv
from flaky import flaky
from langchain_openai import ChatOpenAI
from langchain_core.language_models import BaseChatModel
import pytest
from pytest_mock import MockerFixture
from werewolf.config import GameConfig
Expand All @@ -15,6 +16,8 @@
)
from werewolf.game_master.default_game_master import DefaultGameMaster

load_dotenv()


@pytest.fixture
def game_config() -> GameConfig:
Expand Down Expand Up @@ -157,10 +160,10 @@ def test_DefaultGameMaster__clean_name(
input_name = ' Player0 '
expected = 'Player0'
llm_output = expected
llm_mock = mocker.MagicMock(spec=ChatOpenAI)
llm_mock = mocker.MagicMock(spec=BaseChatModel)
llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output) # noqa
create_chat_openai_model_mock = mocker.patch(
'werewolf.game_master.default_game_master.create_chat_openai_model',
create_chat_model_mock = mocker.patch(
'werewolf.game_master.default_game_master.create_chat_model',
return_value=llm_mock,
autospec=True,
)
Expand All @@ -174,7 +177,7 @@ def test_DefaultGameMaster__clean_name(
actual = master._clean_name(input_name, question='Who do you think should be excluded from the game?') # noqa
# assert
assert actual == expected
create_chat_openai_model_mock.assert_called_once()
create_chat_model_mock.assert_called_once()


@pytest.mark.parametrize(
Expand All @@ -190,10 +193,10 @@ def test_DefaultGameMaster__clean_name_all_fail(
input_name = ''
expected = 'None'
llm_output = 'dummy'
llm_mock = mocker.MagicMock(spec=ChatOpenAI)
llm_mock = mocker.MagicMock(spec=BaseChatModel)
llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output) # noqa
_ = mocker.patch(
'werewolf.game_master.default_game_master.create_chat_openai_model',
'werewolf.game_master.default_game_master.create_chat_model',
return_value=llm_mock,
autospec=True,
)
Expand Down Expand Up @@ -223,13 +226,13 @@ def test_DefaultGameMaster__clean_name_n_fails(
expected = 'Player0'
llm_output_fail = 'dummy'
llm_output = expected
llm_mock = mocker.MagicMock(spec=ChatOpenAI)
llm_mock = mocker.MagicMock(spec=BaseChatModel)
llm_mock.invoke.side_effect = [
namedtuple('BaseMessage', ['content'])(llm_output_fail)
for _ in range(n_fails)
] + [namedtuple('BaseMessage', ['content'])(llm_output)]
_ = mocker.patch(
'werewolf.game_master.default_game_master.create_chat_openai_model',
'werewolf.game_master.default_game_master.create_chat_model',
return_value=llm_mock,
autospec=True,
)
Expand Down Expand Up @@ -391,10 +394,10 @@ def test_DefaultGameMaster_ask_to_vote_without_last_message_content(
# init
expected: str = 'None'
llm_output = 'dummy'
llm_mock = mocker.MagicMock(spec=ChatOpenAI)
llm_mock = mocker.MagicMock(spec=BaseChatModel)
llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output) # noqa
_ = mocker.patch(
'werewolf.game_master.default_game_master.create_chat_openai_model',
'werewolf.game_master.default_game_master.create_chat_model',
return_value=llm_mock,
autospec=True,
)
Expand Down
117 changes: 117 additions & 0 deletions tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os
from dotenv import load_dotenv
import pytest
from pytest_mock import MockerFixture
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from werewolf.chat_models import create_chat_model, _service2cls
from werewolf.const import MODEL_SERVICE_MAP

load_dotenv()

name2cls: dict[str, type[BaseChatModel]] = {
k: _service2cls[v]
for k, v in MODEL_SERVICE_MAP.items()
}


@pytest.mark.usefixtures('api_keys')
@pytest.mark.parametrize(
'llm, expected',
[(k, v) for _, (k, v) in enumerate(name2cls.items())]
)
def test_create_chat_model_wo_seed(
llm: str,
expected: type[BaseChatModel],
mocker: MockerFixture,
) -> None:
mocker.patch(f'werewolf.chat_models.{expected.__name__}', autospec=True) # noqa
# assert isinstance(create_chat_model(llm), expected)
assert create_chat_model(llm).__class__.__name__ == expected.__name__


@pytest.mark.usefixtures('api_keys')
@pytest.mark.parametrize(
'llm, seed, expected',
[(k, i, v) for i, (k, v) in enumerate(name2cls.items())]
)
def test_create_chat_model_w_seed(
llm: str,
seed: int,
expected: type[BaseChatModel],
mocker: MockerFixture,
) -> None:
cls_mock = mocker.patch(f'werewolf.chat_models.{expected.__name__}', autospec=True) # noqa
# assert isinstance(create_chat_model(llm, seed), expected)
assert create_chat_model(llm, seed).__class__.__name__ == expected.__name__

# TODO: fix the following assertion
# cls_mock.assert_called_once_with(model=llm, seed=seed)


def test_create_chat_model_w_invalid_llm() -> None:
with pytest.raises(ValueError):
create_chat_model('invalid')


@pytest.mark.integration
@pytest.mark.skipif(
os.getenv("OPENAI_API_KEY") is None,
reason="OPENAI_API is not set.",
)
@pytest.mark.parametrize(
'llm',
[
'gpt-4o-mini',
'gpt-4',
'gpt-4-turbo',
'gpt-4o',
'gpt-3.5-turbo',
]
)
def test_create_chat_model_for_ChatOpenAI_integration(llm: str) -> None:
assert isinstance(create_chat_model(llm), ChatOpenAI)


@pytest.mark.integration
@pytest.mark.skipif(
os.getenv("GROQ_API_KEY") is None,
reason="GROQ_API_KEY is not set.",
)
@pytest.mark.parametrize(
'llm',
[
'llama3-groq-70b-8192-tool-use-preview',
'llama3-groq-8b-8192-tool-use-preview',
'llama-3.1-70b-versatile',
'llama-3.1-8b-instant',
'llama-guard-3-8b',
'llava-v1.5-7b-4096-preview',
'llama3-70b-8192',
'llama3-8b-8192',
'mixtral-8x7b-32768',
'gemma2-9b-it',
'gemma2-7b-it',
]
)
def test_create_chat_model_for_ChatGroq_integration(llm: str) -> None:
assert isinstance(create_chat_model(llm), ChatGroq)


@pytest.mark.integration
@pytest.mark.skipif(
os.getenv("GOOGLE_API_KEY") is None,
reason="GOOGLE_API_KEY is not set.",
)
@pytest.mark.parametrize(
'llm',
[
'gemini-1.5-flash',
'gemini-pro-vision',
'gemini-pro',
]
)
def test_create_chat_model_for_ChatGoogleGenerativeAI_integration(llm: str) -> None: # noqa
assert isinstance(create_chat_model(llm), ChatGoogleGenerativeAI)
22 changes: 11 additions & 11 deletions tests/utils/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,25 @@
load_dotenv()


def test_create_chat_openai_model_with_none(mocker: MockerFixture):
def test_create_chat_openai_model_with_none(mocker: MockerFixture) -> None: # noqa
coai_mock = mocker.patch("werewolf.utils.openai.ChatOpenAI")
create_chat_openai_model()
create_chat_openai_model() # type: ignore
coai_mock.assert_called_once_with(model=DEFAULT_MODEL, seed=None)


@pytest.mark.parametrize("model_name", ["gpt-3.5-turbo", "gpt-4o-mini"])
def test_create_chat_openai_model_with_str(
model_name: str,
mocker: MockerFixture,
):
) -> None:
coai_mock = mocker.patch("werewolf.utils.openai.ChatOpenAI")
create_chat_openai_model(model_name)
create_chat_openai_model(model_name) # type: ignore
coai_mock.assert_called_once_with(model=model_name, seed=None)


def test_create_chat_openai_model_with_instance(mocker: MockerFixture):
def test_create_chat_openai_model_with_instance(mocker: MockerFixture) -> None: # noqa
llm = mocker.MagicMock(spec=ChatOpenAI)
actual = create_chat_openai_model(llm)
actual: ChatOpenAI = create_chat_openai_model(llm)
assert actual is llm


Expand All @@ -43,13 +43,13 @@ def test_create_chat_openai_model_return_same_instance_for_same_input(
llm: str,
seed: int,
mocker: MockerFixture,
):
) -> None:
ChatOpenAI_mock = mocker.patch(
"werewolf.utils.openai.ChatOpenAI",
return_value=mocker.MagicMock(spec=ChatOpenAI),
)
actual1 = create_chat_openai_model(llm, seed)
actual2 = create_chat_openai_model(llm, seed)
actual1: ChatOpenAI = create_chat_openai_model(llm, seed) # type: ignore
actual2: ChatOpenAI = create_chat_openai_model(llm, seed) # type: ignore
assert actual1 is actual2
ChatOpenAI_mock.assert_called_once_with(model=llm, seed=seed)

Expand All @@ -66,6 +66,6 @@ def test_create_chat_openai_model_return_same_instance_for_same_input(
None,
],
)
def test_create_chat_openai_model_with_real_instance(llm: str | None):
actual = create_chat_openai_model(llm)
def test_create_chat_openai_model_with_real_instance(llm: str | None) -> None:
actual: ChatOpenAI = create_chat_openai_model(llm) # type: ignore
assert isinstance(actual, ChatOpenAI)
56 changes: 56 additions & 0 deletions werewolf/chat_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from functools import lru_cache
from logging import getLogger, Logger
from langchain_core.language_models import BaseChatModel
# from langchain_community.chat_models import ChatPerplexity
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI

from .const import DEFAULT_MODEL, EChatService, MODEL_SERVICE_MAP


_service2cls: dict[EChatService, type[BaseChatModel]] = {
EChatService.OpenAI: ChatOpenAI,
EChatService.Google: ChatGoogleGenerativeAI,
EChatService.Groq: ChatGroq,
}


@lru_cache(maxsize=None)
def create_chat_model(
llm: BaseChatModel | str = DEFAULT_MODEL,
seed: int | None = None,
logger: Logger = getLogger(__name__),
**kwargs,
) -> BaseChatModel:
"""Create a ChatModel instance.
Args:
llm (BaseChatModel | str, optional): ChatModel instance or model name. Defaults to DEFAULT_MODEL.
seed (int, optional): Random seed. Defaults to None.
logger (Logger, optional): Logger. Defaults to getLogger(__name__).
Raises:
ValueError: Unknown model name
Returns:
BaseChatModel: ChatModel instance
Note:
seed is used only when llm is a str.
The same parameters return the same instance.
""" # noqa
llm = llm or DEFAULT_MODEL
if isinstance(llm, str):
try:
if seed is not None:
return _service2cls[MODEL_SERVICE_MAP[llm]](model=llm, seed=seed, **kwargs) # type: ignore # noqa
else:
return _service2cls[MODEL_SERVICE_MAP[llm]](model=llm, **kwargs) # type: ignore # noqa
except TypeError:
logger.warning(f'{llm} does not support seed.')
return _service2cls[MODEL_SERVICE_MAP[llm]](model=llm, **kwargs) # type: ignore # noqa
except KeyError:
raise ValueError(f'Unknown model name: {llm}')
else:
return llm
Loading

0 comments on commit e17e2ad

Please sign in to comment.