From 642468fbd26536d4abcc67a68cbba58423ba154f Mon Sep 17 00:00:00 2001
From: hmasdev
Date: Thu, 5 Sep 2024 22:42:10 +0900
Subject: [PATCH 1/7] chore: Update dependencies and imports for langchain_core

---
 tests/game_master/test_default_game_master.py | 10 +++++-----
 werewolf/game.py                              |  4 ++--
 werewolf/game_master/default_game_master.py   |  4 ++--
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/tests/game_master/test_default_game_master.py b/tests/game_master/test_default_game_master.py
index 7d2375b..7924192 100644
--- a/tests/game_master/test_default_game_master.py
+++ b/tests/game_master/test_default_game_master.py
@@ -3,7 +3,7 @@
 import os
 import autogen
 from flaky import flaky
-from langchain_openai import ChatOpenAI
+from langchain_core.language_models import BaseChatModel
 import pytest
 from pytest_mock import MockerFixture
 from werewolf.config import GameConfig
@@ -157,7 +157,7 @@ def test_DefaultGameMaster__clean_name(
     input_name = ' Player0 '
     expected = 'Player0'
     llm_output = expected
-    llm_mock = mocker.MagicMock(spec=ChatOpenAI)
+    llm_mock = mocker.MagicMock(spec=BaseChatModel)
     llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output)  # noqa
     create_chat_openai_model_mock = mocker.patch(
         'werewolf.game_master.default_game_master.create_chat_openai_model',
@@ -190,7 +190,7 @@ def test_DefaultGameMaster__clean_name_all_fail(
     input_name = ''
     expected = 'None'
     llm_output = 'dummy'
-    llm_mock = mocker.MagicMock(spec=ChatOpenAI)
+    llm_mock = mocker.MagicMock(spec=BaseChatModel)
     llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output)  # noqa
     _ = mocker.patch(
         'werewolf.game_master.default_game_master.create_chat_openai_model',
@@ -223,7 +223,7 @@ def test_DefaultGameMaster__clean_name_n_fails(
     expected = 'Player0'
     llm_output_fail = 'dummy'
     llm_output = expected
-    llm_mock = mocker.MagicMock(spec=ChatOpenAI)
+    llm_mock = mocker.MagicMock(spec=BaseChatModel)
     llm_mock.invoke.side_effect = [
         namedtuple('BaseMessage', ['content'])(llm_output_fail)
         for _ in range(n_fails)
@@ -391,7 +391,7 @@ def test_DefaultGameMaster_ask_to_vote_without_last_message_content(
     # init
     expected: str = 'None'
     llm_output = 'dummy'
-    llm_mock = mocker.MagicMock(spec=ChatOpenAI)
+    llm_mock = mocker.MagicMock(spec=BaseChatModel)
     llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output)  # noqa
     _ = mocker.patch(
         'werewolf.game_master.default_game_master.create_chat_openai_model',
diff --git a/werewolf/game.py b/werewolf/game.py
index 0be8af3..4c3f853 100644
--- a/werewolf/game.py
+++ b/werewolf/game.py
@@ -3,7 +3,7 @@
 from typing import Iterable

 import autogen
-from langchain_openai import ChatOpenAI
+from langchain_core.language_models import BaseChatModel

 from .const import DEFAULT_MODEL, EGameMaster
 from .config import GameConfig
@@ -22,7 +22,7 @@ def game(
     include_human: bool = False,
     open_game: bool = False,
     config_list=[{'model': DEFAULT_MODEL}],
-    llm: ChatOpenAI | str | None = None,
+    llm: BaseChatModel | str | None = None,
     printer: str = 'click.echo',
     log_file: str = 'werewolf.log',
     logger: logging.Logger = logging.getLogger(__name__),
diff --git a/werewolf/game_master/default_game_master.py b/werewolf/game_master/default_game_master.py
index efd26ee..c65810f 100644
--- a/werewolf/game_master/default_game_master.py
+++ b/werewolf/game_master/default_game_master.py
@@ -12,10 +12,10 @@
     EnumOutputParser,
     RetryWithErrorOutputParser,
 )
+from langchain_core.language_models import BaseChatModel
 from langchain_core.prompt_values import StringPromptValue
 from langchain_core.runnables import Runnable, RunnableLambda
 from langchain.output_parsers.retry import NAIVE_RETRY_WITH_ERROR_PROMPT
-from langchain_openai import ChatOpenAI

 from ..alias import WhoToVote
 from .base import BaseGameMaster
@@ -190,7 +190,7 @@ def _clean_name(
         self,
         name: str,
         question: str,
-        llm: ChatOpenAI | str | None = None,
+        llm: BaseChatModel | str | None = None,
         max_retry: int = 5,
     ) -> str:  # noqa
         logging.debug(f'Clean name: {name}')

From 8cab405c60b3603c7c94fc4f660cc583712a653d Mon Sep 17 00:00:00 2001
From: hmasdev
Date: Sat, 7 Sep 2024 19:11:10 +0900
Subject: [PATCH 2/7] chore: add Groq and Gemini and update dependencies

---
 pyproject.toml                                |   4 +-
 tests/game_master/test_default_game_master.py |  15 ++-
 tests/test_chat_models.py                     | 113 ++++++++++++++++++
 werewolf/chat_models.py                       |  56 +++++++++
 werewolf/const.py                             |  36 ++++++
 werewolf/game.py                              |   5 +-
 werewolf/game_master/default_game_master.py   |   4 +-
 werewolf/main.py                              |  30 +++--
 8 files changed, 245 insertions(+), 18 deletions(-)
 create mode 100644 tests/test_chat_models.py
 create mode 100644 werewolf/chat_models.py

diff --git a/pyproject.toml b/pyproject.toml
index eb21661..78ba2ca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,8 +6,10 @@ requires-python = ">=3.10"
 dependencies = [
     "click>=8.1.7",
     "langchain>=0.2.0",
+    "langchain-groq>=0.1.9",
+    "langchain-google-genai>=1.0.10",
     "langchain-openai",
-    "pyautogen==0.2.16",
+    "pyautogen[gemini,groq]>=0.2.32",
     "python-dotenv>=1.0.1",
 ]
 authors = [{ name = "hmasdev" }]
diff --git a/tests/game_master/test_default_game_master.py b/tests/game_master/test_default_game_master.py
index 7924192..babcf4d 100644
--- a/tests/game_master/test_default_game_master.py
+++ b/tests/game_master/test_default_game_master.py
@@ -2,6 +2,7 @@
 from dataclasses import asdict
 import os
 import autogen
+from dotenv import load_dotenv
 from flaky import flaky
 from langchain_core.language_models import BaseChatModel
 import pytest
@@ -15,6 +16,8 @@
 )
 from werewolf.game_master.default_game_master import DefaultGameMaster

+load_dotenv()
+

 @pytest.fixture
 def game_config() -> GameConfig:
@@ -159,8 +162,8 @@ def test_DefaultGameMaster__clean_name(
     llm_output = expected
     llm_mock = mocker.MagicMock(spec=BaseChatModel)
     llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output)  # noqa
-    create_chat_openai_model_mock = mocker.patch(
-        'werewolf.game_master.default_game_master.create_chat_openai_model',
+    create_chat_model_mock = mocker.patch(
+        'werewolf.game_master.default_game_master.create_chat_model',
         return_value=llm_mock,
         autospec=True,
     )
@@ -174,7 +177,7 @@ def test_DefaultGameMaster__clean_name(
     actual = master._clean_name(input_name, question='Who do you think should be excluded from the game?')  # noqa
     # assert
     assert actual == expected
-    create_chat_openai_model_mock.assert_called_once()
+    create_chat_model_mock.assert_called_once()


 @pytest.mark.parametrize(
@@ -193,7 +196,7 @@ def test_DefaultGameMaster__clean_name_all_fail(
     llm_mock = mocker.MagicMock(spec=BaseChatModel)
     llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output)  # noqa
     _ = mocker.patch(
-        'werewolf.game_master.default_game_master.create_chat_openai_model',
+        'werewolf.game_master.default_game_master.create_chat_model',
         return_value=llm_mock,
         autospec=True,
     )
@@ -229,7 +232,7 @@ def test_DefaultGameMaster__clean_name_n_fails(
         for _ in range(n_fails)
     ] + [namedtuple('BaseMessage', ['content'])(llm_output)]
     _ = mocker.patch(
-        'werewolf.game_master.default_game_master.create_chat_openai_model',
+        'werewolf.game_master.default_game_master.create_chat_model',
         return_value=llm_mock,
         autospec=True,
     )
@@ -394,7 +397,7 @@ def test_DefaultGameMaster_ask_to_vote_without_last_message_content(
     llm_mock = mocker.MagicMock(spec=BaseChatModel)
     llm_mock.invoke.return_value = namedtuple('BaseMessage', ['content'])(llm_output)  # noqa
     _ = mocker.patch(
-        'werewolf.game_master.default_game_master.create_chat_openai_model',
+        'werewolf.game_master.default_game_master.create_chat_model',
         return_value=llm_mock,
         autospec=True,
     )
diff --git a/tests/test_chat_models.py b/tests/test_chat_models.py
new file mode 100644
index 0000000..d08e37a
--- /dev/null
+++ b/tests/test_chat_models.py
@@ -0,0 +1,113 @@
+import os
+from dotenv import load_dotenv
+import pytest
+from pytest_mock import MockerFixture
+from langchain_core.language_models import BaseChatModel
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_openai import ChatOpenAI
+from werewolf.chat_models import create_chat_model, _service2cls
+from werewolf.const import MODEL_SERVICE_MAP
+
+load_dotenv()
+
+name2cls: dict[str, type[BaseChatModel]] = {
+    k: _service2cls[v]
+    for k, v in MODEL_SERVICE_MAP.items()
+}
+
+
+@pytest.mark.parametrize(
+    'llm, expected',
+    [(k, v) for _, (k, v) in enumerate(name2cls.items())]
+)
+def test_create_chat_model_wo_seed(
+    llm: str,
+    expected: type[BaseChatModel],
+    mocker: MockerFixture,
+) -> None:
+    mocker.patch(f'werewolf.chat_models.{expected.__name__}', autospec=True)  # noqa
+    assert isinstance(create_chat_model(llm), expected)
+
+
+@pytest.mark.parametrize(
+    'llm, seed, expected',
+    [(k, i, v) for i, (k, v) in enumerate(name2cls.items())]
+)
+def test_create_chat_model_w_seed(
+    llm: str,
+    seed: int,
+    expected: type[BaseChatModel],
+    mocker: MockerFixture,
+) -> None:
+    cls_mock = mocker.patch(f'werewolf.chat_models.{expected.__name__}', autospec=True)  # noqa
+    assert isinstance(create_chat_model(llm, seed), expected)
+
+    # TODO: fix the following assertion
+    # cls_mock.assert_called_once_with(model=llm, seed=seed)
+
+
+def test_create_chat_model_w_invalid_llm() -> None:
+    with pytest.raises(ValueError):
+        create_chat_model('invalid')
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(
+    os.getenv("OPENAI_API_KEY") is None,
+    reason="OPENAI_API_KEY is not set.",
+)
+@pytest.mark.parametrize(
+    'llm',
+    [
+        'gpt-4o-mini',
+        'gpt-4',
+        'gpt-4-turbo',
+        'gpt-4o',
+        'gpt-3.5-turbo',
+    ]
+)
+def test_create_chat_model_for_ChatOpenAI_integration(llm: str) -> None:
+    assert isinstance(create_chat_model(llm), ChatOpenAI)
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(
+    os.getenv("GROQ_API_KEY") is None,
+    reason="GROQ_API_KEY is not set.",
+)
+@pytest.mark.parametrize(
+    'llm',
+    [
+        'llama3-groq-70b-8192-tool-use-preview',
+        'llama3-groq-8b-8192-tool-use-preview',
+        'llama-3.1-70b-versatile',
+        'llama-3.1-8b-instant',
+        'llama-guard-3-8b',
+        'llava-v1.5-7b-4096-preview',
+        'llama3-70b-8192',
+        'llama3-8b-8192',
+        'mixtral-8x7b-32768',
+        'gemma2-9b-it',
+        'gemma2-7b-it',
+    ]
+)
+def test_create_chat_model_for_ChatGroq_integration(llm: str) -> None:
+    assert isinstance(create_chat_model(llm), ChatGroq)
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(
+    os.getenv("GOOGLE_API_KEY") is None,
+    reason="GOOGLE_API_KEY is not set.",
+)
+@pytest.mark.parametrize(
+    'llm',
+    [
+        'gemini-1.5-flash',
+        'gemini-pro-vision',
+        'gemini-pro',
+    ]
+)
+def test_create_chat_model_for_ChatGoogleGenerativeAI_integration(llm: str) -> None:  # noqa
+    assert isinstance(create_chat_model(llm), ChatGoogleGenerativeAI)
diff --git a/werewolf/chat_models.py b/werewolf/chat_models.py
new file mode 100644
index 0000000..3936fcc
--- /dev/null
+++ b/werewolf/chat_models.py
@@ -0,0 +1,56 @@
+from functools import lru_cache
+from logging import getLogger, Logger
+from langchain_core.language_models import BaseChatModel
+# from langchain_community.chat_models import ChatPerplexity
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_openai import ChatOpenAI
+
+from .const import DEFAULT_MODEL, EChatService, MODEL_SERVICE_MAP
+
+
+_service2cls: dict[EChatService, type[BaseChatModel]] = {
+    EChatService.OpenAI: ChatOpenAI,
+    EChatService.Google: ChatGoogleGenerativeAI,
+    EChatService.Groq: ChatGroq,
+}
+
+
+@lru_cache(maxsize=None)
+def create_chat_model(
+    llm: BaseChatModel | str = DEFAULT_MODEL,
+    seed: int | None = None,
+    logger: Logger = getLogger(__name__),
+    **kwargs,
+) -> BaseChatModel:
+    """Create a ChatModel instance.
+
+    Args:
+        llm (BaseChatModel | str, optional): ChatModel instance or model name. Defaults to DEFAULT_MODEL.
+        seed (int, optional): Random seed. Defaults to None.
+        logger (Logger, optional): Logger. Defaults to getLogger(__name__).
+
+    Raises:
+        ValueError: Unknown model name
+
+    Returns:
+        BaseChatModel: ChatModel instance
+
+    Note:
+        seed is used only when llm is a str.
+        The same parameters return the same instance.
+    """  # noqa
+    llm = llm or DEFAULT_MODEL
+    if isinstance(llm, str):
+        try:
+            if seed is not None:
+                return _service2cls[MODEL_SERVICE_MAP[llm]](model=llm, seed=seed, **kwargs)  # type: ignore # noqa
+            else:
+                return _service2cls[MODEL_SERVICE_MAP[llm]](model=llm, **kwargs)  # type: ignore # noqa
+        except TypeError:
+            logger.warning(f'{llm} does not support seed.')
+            return _service2cls[MODEL_SERVICE_MAP[llm]](model=llm, **kwargs)  # type: ignore # noqa
+        except KeyError:
+            raise ValueError(f'Unknown model name: {llm}')
+    else:
+        return llm
diff --git a/werewolf/const.py b/werewolf/const.py
index 078ffbb..eb1ccea 100644
--- a/werewolf/const.py
+++ b/werewolf/const.py
@@ -43,3 +43,39 @@ class EStatus(Enum):
 class EResult(Enum):
     VillagersWin: str = 'VillagersWin'
     WerewolvesWin: str = 'WerewolvesWin'
+
+
+class EChatService(Enum):
+    OpenAI: str = 'openai'
+    Google: str = 'google'
+    Groq: str = 'groq'
+
+
+MODEL_SERVICE_MAP: dict[str, EChatService] = {
+    'gpt-3.5-turbo': EChatService.OpenAI,
+    'gpt-4': EChatService.OpenAI,
+    'gpt-4-turbo': EChatService.OpenAI,
+    'gpt-4o': EChatService.OpenAI,
+    'gpt-4o-mini': EChatService.OpenAI,
+    'gemini-1.5-flash': EChatService.Google,
+    "gemini-pro-vision": EChatService.Google,
+    'gemini-pro': EChatService.Google,
+    'mixtral-8x7b-32768': EChatService.Groq,
+    'gemma2-9b-it': EChatService.Groq,
+    'gemma2-7b-it': EChatService.Groq,
+    'llama3-groq-70b-8192-tool-use-preview': EChatService.Groq,
+    'llama3-groq-8b-8192-tool-use-preview': EChatService.Groq,
+    'llama-3.1-70b-versatile': EChatService.Groq,
+    'llama-3.1-8b-instant': EChatService.Groq,
+    'llama-guard-3-8b': EChatService.Groq,
+    'llava-v1.5-7b-4096-preview': EChatService.Groq,
+    'llama3-70b-8192': EChatService.Groq,
+    'llama3-8b-8192': EChatService.Groq,
+    'mixtral-8x7b-32768': EChatService.Groq,
+}
+VALID_MODELS: tuple[str, ...] = tuple(MODEL_SERVICE_MAP.keys())
+SERVICE_APIKEY_ENVVAR_MAP: dict[EChatService, str] = {
+    EChatService.Google: 'GOOGLE_API_KEY',
+    EChatService.Groq: 'GROQ_API_KEY',
+    EChatService.OpenAI: 'OPENAI_API_KEY',
+}
diff --git a/werewolf/game.py b/werewolf/game.py
index 4c3f853..5b5fef4 100644
--- a/werewolf/game.py
+++ b/werewolf/game.py
@@ -5,10 +5,10 @@
 import autogen
 from langchain_core.language_models import BaseChatModel

+from .chat_models import create_chat_model
 from .const import DEFAULT_MODEL, EGameMaster
 from .config import GameConfig
 from .game_master.base import BaseGameMaster
-from .utils.openai import create_chat_openai_model
 from .utils.printer import create_print_func
@@ -23,13 +23,14 @@ def game(
     open_game: bool = False,
     config_list=[{'model': DEFAULT_MODEL}],
     llm: BaseChatModel | str | None = None,
+    seed: int | None = None,
     printer: str = 'click.echo',
     log_file: str = 'werewolf.log',
     logger: logging.Logger = logging.getLogger(__name__),
 ):
     # preparation
     print_func = create_print_func(printer)
-    llm = create_chat_openai_model(llm)
+    llm = create_chat_model(llm, seed=seed)
     master = BaseGameMaster.instantiate(
         EGameMaster.Default,  # TODO
         groupchat=autogen.GroupChat(agents=[], messages=[]),
diff --git a/werewolf/game_master/default_game_master.py b/werewolf/game_master/default_game_master.py
index c65810f..58b4567 100644
--- a/werewolf/game_master/default_game_master.py
+++ b/werewolf/game_master/default_game_master.py
@@ -20,6 +20,7 @@
 from ..alias import WhoToVote
 from .base import BaseGameMaster
 from ..base import IWerewolfPlayer
+from ..chat_models import create_chat_model
 from ..const import (
     EResult,
     ERole,
@@ -31,7 +32,6 @@
 from ..config import GameConfig
 from ..game_player.base import BaseWerewolfPlayer
 from ..utils.autogen_utils import just1turn
-from ..utils.openai import create_chat_openai_model
 from ..utils.utils import (
     consecutive_string_generator,
     instant_decoration,
@@ -195,7 +195,7 @@ def _clean_name(
     ) -> str:  # noqa
         logging.debug(f'Clean name: {name}')
         base_llm_chain: Runnable[str, str] = (
-            create_chat_openai_model(llm)
+            create_chat_model(llm)
             | RunnableLambda(attrgetter('content'))
         )
         chain = RetryWithErrorOutputParser.from_llm(
diff --git a/werewolf/main.py b/werewolf/main.py
index 6a24ec6..6e72ce9 100644
--- a/werewolf/main.py
+++ b/werewolf/main.py
@@ -5,7 +5,14 @@
 import click
 from dotenv import load_dotenv

-from .const import DEFAULT_MODEL, ESpeakerSelectionMethod
+from .const import (
+    DEFAULT_MODEL,
+    EChatService,
+    ESpeakerSelectionMethod,
+    MODEL_SERVICE_MAP,
+    SERVICE_APIKEY_ENVVAR_MAP,
+    VALID_MODELS,
+)
 from .game import game
 from .utils.printer import KEYS_FOR_PRINTER, create_print_func
@@ -20,10 +27,10 @@
 @click.option('-h', '--include-human', is_flag=True, help='Whether to include human or not.')  # noqa
 @click.option('-o', '--open-game', is_flag=True, help='Whether to open game or not.')  # noqa
 @click.option('-l', '--log', default=None, help='The log file name. Default is werewolf%Y%m%d%H%M%S.log')  # noqa
-@click.option('-m', '--model', default=DEFAULT_MODEL, help=f'The model name. Default is {DEFAULT_MODEL}.')  # noqa
+@click.option('-m', '--model', default=DEFAULT_MODEL, help=f'The model name. Default is {DEFAULT_MODEL}. Valid models are: {VALID_MODELS}, or a Hugging Face repository id.')  # noqa
 @click.option('-p', '--printer', default='click.echo', help=f'The printer name. The valid values is in {KEYS_FOR_PRINTER}. Default is click.echo.')  # noqa
-@click.option('--sub-model', default=DEFAULT_MODEL, help=f'The sub-model name. Default is {DEFAULT_MODEL}.')  # noqa
-@click.option('--seed', default=None, help='The random seed.')  # noqa
+@click.option('--sub-model', default=DEFAULT_MODEL, help=f'The sub-model name. Default is {DEFAULT_MODEL}. Valid models are: {VALID_MODELS}, or a Hugging Face repository id.')  # noqa
+@click.option('--seed', default=-1, help='The random seed. Defaults to -1. NOTE: a negative integer means "do not specify the seed".')  # noqa
 @click.option('--log-level', default='WARNING', help='The log level, DEBUG, INFO, WARNING, ERROR or CRITICAL. Default is WARNING.')  # noqa
 @click.option('--debug', is_flag=True, help='Whether to show debug logs or not.')  # noqa
 def main(
@@ -39,11 +46,11 @@ def main(
     model: str,
     printer: str,
     sub_model: str,
-    seed: int | None,
+    seed: int,
     log_level: str,
     debug: bool,
 ):
-    if seed is not None:
+    if seed >= 0:
         import random
         random.seed(seed)
@@ -55,7 +62,15 @@ def main(
     log = f'werewolf{dt.now().strftime("%Y%m%d%H%M%S")}.log' if log is None else log  # noqa
     logging.basicConfig(level=logging.DEBUG if debug else getattr(logging, log_level.upper(), 'WARNING'))  # type: ignore # noqa

-    config_list = [{'model': model}]
+    if (service := MODEL_SERVICE_MAP[model]) == EChatService.OpenAI:
+        config_list = [{'model': model}]
+    else:
+        config_list = [{
+            'model': model,
+            'api_type': service.value,
+            # FIXME: this is not good because the raw token is on memory
+            'api_key': os.getenv(SERVICE_APIKEY_ENVVAR_MAP[service]),
+        }]

     result = game(
         n_players=n_players,
@@ -70,6 +85,7 @@ def main(
         log_file=log,
         config_list=config_list,
         llm=sub_model,
+        seed=seed,
     )

     printer_func('================================ Game Result ================================')  # noqa

From 463685c8ed64f977917f10674dc63d614b79dcbe Mon Sep 17 00:00:00 2001
From: hmasdev
Date: Sat, 7 Sep 2024 20:34:28 +0900
Subject: [PATCH 3/7] add deprecation decorator to create_chat_openai_model

---
 werewolf/utils/openai.py |  2 ++
 werewolf/utils/utils.py  | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+)

diff --git a/werewolf/utils/openai.py b/werewolf/utils/openai.py
index d9e8ef5..e696b11 100644
--- a/werewolf/utils/openai.py
+++ b/werewolf/utils/openai.py
@@ -1,8 +1,10 @@
 from functools import lru_cache
 from langchain_openai import ChatOpenAI
 from ..const import DEFAULT_MODEL
+from ..utils.utils import deprecate


+@deprecate(msg='Use werewolf.chat_models.create_chat_model instead.')
 @lru_cache(maxsize=None)
 def create_chat_openai_model(
     llm: ChatOpenAI | str | None = None,
diff --git a/werewolf/utils/utils.py b/werewolf/utils/utils.py
index 6408922..a78bcf6 100644
--- a/werewolf/utils/utils.py
+++ b/werewolf/utils/utils.py
@@ -1,4 +1,6 @@
 from contextlib import contextmanager
+from functools import partial
+from logging import Logger, getLogger
 from typing import Callable, Generator, Iterable, TypeVar

 InputType = TypeVar('InputType')
@@ -47,3 +49,36 @@ def consecutive_string_generator(
     while True:
         yield f'{prefix}{idx}'
         idx += step
+
+
+def deprecate(
+    func: Callable[[InputType], OutputType] | None = None,
+    *,
+    msg: str = '',
+    logger: Logger = getLogger(__name__),
+) -> (
+    Callable[[InputType], OutputType]
+    | Callable[[Callable[[InputType], OutputType]], Callable[[InputType], OutputType]]
+):
+    """Decorator to deprecate a function.
+
+    Args:
+        func (Callable[[InputType], OutputType] | None, optional): function to deprecate. Defaults to None.
+        msg (str, optional): message to show. Defaults to ''.
+        logger (Logger, optional): logger. Defaults to getLogger(__name__).
+
+    Returns:
+        Callable[[InputType], OutputType] | Callable[[Callable[[InputType], OutputType]], Callable[[InputType], OutputType]]: decorated function or decorator
+    """  # noqa
+
+    if func is None:
+        return partial(deprecate, msg=msg, logger=logger)
+
+    def wrapper(*args, **kwargs):
+        if hasattr(func, '__name__'):
+            logger.warning(' '.join([f'{func.__name__} is planned to be deprecated.', msg]))  # noqa
+        else:
+            logger.warning(' '.join([f'{func} is planned to be deprecated.', msg]))  # noqa
+        return func(*args, **kwargs)
+
+    return wrapper

From 1087126a00c442933c2c312863e8ed4caeba4b98 Mon Sep 17 00:00:00 2001
From: hmasdev
Date: Sat, 7 Sep 2024 20:38:49 +0900
Subject: [PATCH 4/7] Update README.md with API key instructions for OpenAI,
 Groq, and Gemini

---
 README.md.j2 | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/README.md.j2 b/README.md.j2
index eaf9838..e24b544 100644
--- a/README.md.j2
+++ b/README.md.j2
@@ -10,8 +10,13 @@

 ## Requirements

-- OpenAI API Key
-  - [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys)
+- Get your API key
+  - OpenAI API Key
+    - [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys)
+  - Groq API Key
+    - [https://console.groq.com/keys](https://console.groq.com/keys)
+  - Gemini API Key
+    - [https://aistudio.google.com/app/apikey](https://aistudio.google.com/app/apikey)
 - (optional) `docker compose`
 - (optional) python >= 3.10
@@ -23,10 +28,13 @@ Note that either of `docker compose` or `python` is required.
 ### Preparation

 1. Create `.env` file
-2. Set `OPENAI_API_KEY`:
+2. Set `OPENAI_API_KEY`, `GROQ_API_KEY` or `GOOGLE_API_KEY` in the `.env` file as follows:

    ```text
-   OPENAI_API_KEY=HERE_IS_YOUR_OPENAI_API_KEY
+   OPENAI_API_KEY=HERE_IS_YOUR_API_KEY
+   GROQ_API_KEY=HERE_IS_YOUR_API_KEY
+   GOOGLE_API_KEY=HERE_IS_YOUR_API_KEY
+
    ```

 3. If you don't use `docker` but `python` in your machine, create a virtual environment and install libraries manually:
From d23172cea708e1ff6d887ece7bc3000990f2a71e Mon Sep 17 00:00:00 2001
From: hmasdev
Date: Sat, 7 Sep 2024 20:50:37 +0900
Subject: [PATCH 5/7] Add fixture to load API keys from environment variables

---
 tests/conftest.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)
 create mode 100644 tests/conftest.py

diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..11ee18c
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,12 @@
+import os
+from dotenv import load_dotenv
+import pytest
+from werewolf.const import SERVICE_APIKEY_ENVVAR_MAP
+
+
+@pytest.fixture(scope='session', autouse=True)
+def api_keys():
+    for envvar in SERVICE_APIKEY_ENVVAR_MAP.values():
+        os.environ[envvar] = os.getenv(envvar)
+    load_dotenv()
+    yield

From 178e4b1c18800357e5fbf08d36d525109a0b6e0a Mon Sep 17 00:00:00 2001
From: hmasdev
Date: Sat, 7 Sep 2024 20:57:03 +0900
Subject: [PATCH 6/7] apply flake8 and mypy

---
 tests/conftest.py          |  2 +-
 tests/utils/test_openai.py | 22 +++++++++++-----------
 werewolf/const.py          |  1 -
 werewolf/main.py           |  2 +-
 werewolf/utils/openai.py   |  2 +-
 werewolf/utils/utils.py    |  4 ++--
 6 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 11ee18c..f1c6242 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,6 @@
 @pytest.fixture(scope='session', autouse=True)
 def api_keys():
     for envvar in SERVICE_APIKEY_ENVVAR_MAP.values():
-        os.environ[envvar] = os.getenv(envvar)
+        os.environ[envvar] = os.getenv(envvar) or ''  # type: ignore
     load_dotenv()
     yield
diff --git a/tests/utils/test_openai.py b/tests/utils/test_openai.py
index 2d379ef..d2dc9a8 100644
--- a/tests/utils/test_openai.py
+++ b/tests/utils/test_openai.py
@@ -9,9 +9,9 @@
 load_dotenv()


-def test_create_chat_openai_model_with_none(mocker: MockerFixture):
+def test_create_chat_openai_model_with_none(mocker: MockerFixture) -> None:  # noqa
     coai_mock = mocker.patch("werewolf.utils.openai.ChatOpenAI")
-    create_chat_openai_model()
+    create_chat_openai_model()  # type: ignore
     coai_mock.assert_called_once_with(model=DEFAULT_MODEL, seed=None)


@@ -19,15 +19,15 @@ def test_create_chat_openai_model_with_none(mocker: MockerFixture):
 def test_create_chat_openai_model_with_str(
     model_name: str,
     mocker: MockerFixture,
-):
+) -> None:
     coai_mock = mocker.patch("werewolf.utils.openai.ChatOpenAI")
-    create_chat_openai_model(model_name)
+    create_chat_openai_model(model_name)  # type: ignore
     coai_mock.assert_called_once_with(model=model_name, seed=None)


-def test_create_chat_openai_model_with_instance(mocker: MockerFixture):
+def test_create_chat_openai_model_with_instance(mocker: MockerFixture) -> None:  # noqa
     llm = mocker.MagicMock(spec=ChatOpenAI)
-    actual = create_chat_openai_model(llm)
+    actual: ChatOpenAI = create_chat_openai_model(llm)
     assert actual is llm


@@ -43,13 +43,13 @@ def test_create_chat_openai_model_return_same_instance_for_same_input(
     llm: str,
     seed: int,
     mocker: MockerFixture,
-):
+) -> None:
     ChatOpenAI_mock = mocker.patch(
         "werewolf.utils.openai.ChatOpenAI",
         return_value=mocker.MagicMock(spec=ChatOpenAI),
     )
-    actual1 = create_chat_openai_model(llm, seed)
-    actual2 = create_chat_openai_model(llm, seed)
+    actual1: ChatOpenAI = create_chat_openai_model(llm, seed)  # type: ignore
+    actual2: ChatOpenAI = create_chat_openai_model(llm, seed)  # type: ignore
     assert actual1 is actual2
     ChatOpenAI_mock.assert_called_once_with(model=llm, seed=seed)


@@ -66,6 +66,6 @@ def test_create_chat_openai_model_return_same_instance_for_same_input(
         None,
     ],
 )
-def test_create_chat_openai_model_with_real_instance(llm: str | None):
-    actual = create_chat_openai_model(llm)
+def test_create_chat_openai_model_with_real_instance(llm: str | None) -> None:
+    actual: ChatOpenAI = create_chat_openai_model(llm)  # type: ignore
     assert isinstance(actual, ChatOpenAI)
diff --git a/werewolf/const.py b/werewolf/const.py
index eb1ccea..71abb36 100644
--- a/werewolf/const.py
+++ b/werewolf/const.py
@@ -60,7 +60,6 @@ class EChatService(Enum):
     'gemini-1.5-flash': EChatService.Google,
     "gemini-pro-vision": EChatService.Google,
     'gemini-pro': EChatService.Google,
-    'mixtral-8x7b-32768': EChatService.Groq,
     'gemma2-9b-it': EChatService.Groq,
     'gemma2-7b-it': EChatService.Groq,
     'llama3-groq-70b-8192-tool-use-preview': EChatService.Groq,
diff --git a/werewolf/main.py b/werewolf/main.py
index 6e72ce9..5f1a677 100644
--- a/werewolf/main.py
+++ b/werewolf/main.py
@@ -69,7 +69,7 @@ def main(
             'model': model,
             'api_type': service.value,
             # FIXME: this is not good because the raw token is on memory
-            'api_key': os.getenv(SERVICE_APIKEY_ENVVAR_MAP[service]),
+            'api_key': os.getenv(SERVICE_APIKEY_ENVVAR_MAP[service]),  # type: ignore # noqa
         }]

     result = game(
diff --git a/werewolf/utils/openai.py b/werewolf/utils/openai.py
index e696b11..b2bf884 100644
--- a/werewolf/utils/openai.py
+++ b/werewolf/utils/openai.py
@@ -4,7 +4,7 @@
 from ..utils.utils import deprecate


-@deprecate(msg='Use werewolf.chat_models.create_chat_model instead.')
+@deprecate(msg='Use werewolf.chat_models.create_chat_model instead.')  # type: ignore # noqa
 @lru_cache(maxsize=None)
 def create_chat_openai_model(
     llm: ChatOpenAI | str | None = None,
diff --git a/werewolf/utils/utils.py b/werewolf/utils/utils.py
index a78bcf6..60e5ba8 100644
--- a/werewolf/utils/utils.py
+++ b/werewolf/utils/utils.py
@@ -58,7 +58,7 @@ def deprecate(
     logger: Logger = getLogger(__name__),
 ) -> (
     Callable[[InputType], OutputType]
-    | Callable[[Callable[[InputType], OutputType]], Callable[[InputType], OutputType]]
+    | Callable[[Callable[[InputType], OutputType]], Callable[[InputType], OutputType]]  # noqa
 ):
     """Decorator to deprecate a function.
@@ -72,7 +72,7 @@ def deprecate(
     """  # noqa

     if func is None:
-        return partial(deprecate, msg=msg, logger=logger)
+        return partial(deprecate, msg=msg, logger=logger)  # type: ignore

     def wrapper(*args, **kwargs):
         if hasattr(func, '__name__'):

From 1b5f1f81ee7643905d45f14d8ab4b2cc27021a18 Mon Sep 17 00:00:00 2001
From: hmasdev
Date: Sat, 7 Sep 2024 21:07:30 +0900
Subject: [PATCH 7/7] Update the fixture to load API keys from environment
 variables

---
 tests/conftest.py         | 10 +++++++---
 tests/test_chat_models.py |  8 ++++++--
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index f1c6242..9aafe10 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,12 +1,16 @@
 import os
+from typing import Generator
 from dotenv import load_dotenv
 import pytest
 from werewolf.const import SERVICE_APIKEY_ENVVAR_MAP


-@pytest.fixture(scope='session', autouse=True)
-def api_keys():
+@pytest.fixture(scope='function')
+def api_keys(dummy: str = 'dummy') -> Generator[None, None, None]:
     for envvar in SERVICE_APIKEY_ENVVAR_MAP.values():
-        os.environ[envvar] = os.getenv(envvar) or ''  # type: ignore
+        os.environ[envvar] = os.getenv(envvar) or dummy  # type: ignore
     load_dotenv()
     yield
+    for envvar in SERVICE_APIKEY_ENVVAR_MAP.values():
+        if os.getenv(envvar) == dummy:
+            del os.environ[envvar]
diff --git a/tests/test_chat_models.py b/tests/test_chat_models.py
index d08e37a..0fdfe7b 100644
--- a/tests/test_chat_models.py
+++ b/tests/test_chat_models.py
@@ -17,6 +17,7 @@
 }


+@pytest.mark.usefixtures('api_keys')
 @pytest.mark.parametrize(
     'llm, expected',
     [(k, v) for _, (k, v) in enumerate(name2cls.items())]
@@ -27,9 +28,11 @@ def test_create_chat_model_wo_seed(
     llm: str,
     expected: type[BaseChatModel],
     mocker: MockerFixture,
 ) -> None:
     mocker.patch(f'werewolf.chat_models.{expected.__name__}', autospec=True)  # noqa
-    assert isinstance(create_chat_model(llm), expected)
+    # assert isinstance(create_chat_model(llm), expected)
+    assert create_chat_model(llm).__class__.__name__ == expected.__name__


+@pytest.mark.usefixtures('api_keys')
 @pytest.mark.parametrize(
     'llm, seed, expected',
     [(k, i, v) for i, (k, v) in enumerate(name2cls.items())]
@@ -41,7 +44,8 @@ def test_create_chat_model_w_seed(
     llm: str,
     seed: int,
     expected: type[BaseChatModel],
     mocker: MockerFixture,
 ) -> None:
     cls_mock = mocker.patch(f'werewolf.chat_models.{expected.__name__}', autospec=True)  # noqa
-    assert isinstance(create_chat_model(llm, seed), expected)
+    # assert isinstance(create_chat_model(llm, seed), expected)
+    assert create_chat_model(llm, seed).__class__.__name__ == expected.__name__

     # TODO: fix the following assertion
     # cls_mock.assert_called_once_with(model=llm, seed=seed)
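---

Reviewer note: below is a minimal, illustrative sketch of how the `create_chat_model` factory introduced in PATCH 2/7 is meant to be used. It is not part of the patch series itself; the model name and seed are arbitrary examples, and it assumes the relevant API key (here `OPENAI_API_KEY`) is available in the environment or an `.env` file, as described in the README changes in PATCH 4/7.

```python
# Illustrative sketch only -- not one of the patches above.
from dotenv import load_dotenv

from werewolf.chat_models import create_chat_model

load_dotenv()  # assumes OPENAI_API_KEY is set in .env or the environment

# The model name is resolved to a service via MODEL_SERVICE_MAP, so the same
# call works for OpenAI ('gpt-4o-mini'), Groq ('llama3-8b-8192'), and
# Gemini ('gemini-1.5-flash') model names.
llm = create_chat_model('gpt-4o-mini', seed=42)

# create_chat_model is wrapped in lru_cache: calling it again with the same
# arguments returns the same instance (see the Note in its docstring).
assert llm is create_chat_model('gpt-4o-mini', seed=42)

# Unknown model names raise ValueError (mapped from the KeyError raised by
# the MODEL_SERVICE_MAP lookup).
try:
    create_chat_model('not-a-known-model')
except ValueError as e:
    print(e)  # -> Unknown model name: not-a-known-model
```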