diff --git a/tests/utils/test_openai.py b/tests/utils/test_openai.py
index 703da30..2d379ef 100644
--- a/tests/utils/test_openai.py
+++ b/tests/utils/test_openai.py
@@ -12,7 +12,7 @@
 def test_create_chat_openai_model_with_none(mocker: MockerFixture):
     coai_mock = mocker.patch("werewolf.utils.openai.ChatOpenAI")
     create_chat_openai_model()
-    coai_mock.assert_called_once_with(model=DEFAULT_MODEL)
+    coai_mock.assert_called_once_with(model=DEFAULT_MODEL, seed=None)
 
 
 @pytest.mark.parametrize("model_name", ["gpt-3.5-turbo", "gpt-4o-mini"])
@@ -22,7 +22,7 @@ def test_create_chat_openai_model_with_str(
 ):
     coai_mock = mocker.patch("werewolf.utils.openai.ChatOpenAI")
     create_chat_openai_model(model_name)
-    coai_mock.assert_called_once_with(model=model_name)
+    coai_mock.assert_called_once_with(model=model_name, seed=None)
 
 
 def test_create_chat_openai_model_with_instance(mocker: MockerFixture):
@@ -31,6 +31,29 @@ def test_create_chat_openai_model_with_instance(mocker: MockerFixture):
     assert actual is llm
 
 
+@pytest.mark.parametrize(
+    'llm, seed',
+    [
+        ('gpt-3.5-turbo', 123),
+        ('gpt-4o-mini', 456),
+        ('gpt-4o-mini', None),
+    ],
+)
+def test_create_chat_openai_model_return_same_instance_for_same_input(
+    llm: str,
+    seed: int | None,
+    mocker: MockerFixture,
+):
+    ChatOpenAI_mock = mocker.patch(
+        "werewolf.utils.openai.ChatOpenAI",
+        return_value=mocker.MagicMock(spec=ChatOpenAI),
+    )
+    actual1 = create_chat_openai_model(llm, seed)
+    actual2 = create_chat_openai_model(llm, seed)
+    assert actual1 is actual2
+    ChatOpenAI_mock.assert_called_once_with(model=llm, seed=seed)
+
+
 @pytest.mark.integration
 @pytest.mark.skipif(
     os.getenv("OPENAI_API_KEY") is None,
diff --git a/werewolf/main.py b/werewolf/main.py
index c68f656..6a24ec6 100644
--- a/werewolf/main.py
+++ b/werewolf/main.py
@@ -23,6 +23,7 @@
 @click.option('-m', '--model', default=DEFAULT_MODEL, help=f'The model name. Default is {DEFAULT_MODEL}.')  # noqa
 @click.option('-p', '--printer', default='click.echo', help=f'The printer name. The valid values is in {KEYS_FOR_PRINTER}. Default is click.echo.')  # noqa
 @click.option('--sub-model', default=DEFAULT_MODEL, help=f'The sub-model name. Default is {DEFAULT_MODEL}.')  # noqa
+@click.option('--seed', default=None, type=int, help='The random seed. Default is None.')  # noqa
 @click.option('--log-level', default='WARNING', help='The log level, DEBUG, INFO, WARNING, ERROR or CRITICAL. Default is WARNING.')  # noqa
 @click.option('--debug', is_flag=True, help='Whether to show debug logs or not.')  # noqa
 def main(
@@ -38,9 +39,14 @@ def main(
     model: str,
     printer: str,
     sub_model: str,
+    seed: int | None,
     log_level: str,
     debug: bool,
 ):
+    if seed is not None:
+        import random
+        random.seed(seed)
+
     load_dotenv()
     if os.environ.get('OPENAI_API_KEY') is None:
         raise ValueError('You must set OPENAI_API_KEY in your environment variables or .env file.')  # noqa
diff --git a/werewolf/utils/openai.py b/werewolf/utils/openai.py
index e894eeb..d9e8ef5 100644
--- a/werewolf/utils/openai.py
+++ b/werewolf/utils/openai.py
@@ -1,22 +1,26 @@
+from functools import lru_cache
 from langchain_openai import ChatOpenAI
 
 from ..const import DEFAULT_MODEL
 
 
+@lru_cache(maxsize=None)
 def create_chat_openai_model(
     llm: ChatOpenAI | str | None = None,
+    seed: int | None = None,
 ) -> ChatOpenAI:
     """Create a ChatOpenAI instance.
 
     Args:
         llm (ChatOpenAI | str | None, optional): ChatOpenAI instance or model name. Defaults to None.
+        seed (int | None, optional): Random seed. Defaults to None.
 
     Returns:
         ChatOpenAI: ChatOpenAI instance
 
     Note:
-
+        seed is used only when llm is a str or None.
""" # noqa if isinstance(llm, str): - return ChatOpenAI(model=llm) + return ChatOpenAI(model=llm, seed=seed) else: - return llm or ChatOpenAI(model=DEFAULT_MODEL) + return llm or ChatOpenAI(model=DEFAULT_MODEL, seed=seed)