Skip to content

Commit

Permalink
chore: Update create_chat_openai_model to include a seed parameter
Browse files Browse the repository at this point in the history
The `create_chat_openai_model` function in the `openai.py` module has been updated to include a new `seed` parameter. This parameter allows for setting a random seed when creating a `ChatOpenAI` instance. This change improves the control and reproducibility of the generated chat responses.
  • Loading branch information
hmasdev committed Sep 5, 2024
1 parent 93e8036 commit 368f652
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
27 changes: 25 additions & 2 deletions tests/utils/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def test_create_chat_openai_model_with_none(mocker: MockerFixture):
coai_mock = mocker.patch("werewolf.utils.openai.ChatOpenAI")
create_chat_openai_model()
coai_mock.assert_called_once_with(model=DEFAULT_MODEL)
coai_mock.assert_called_once_with(model=DEFAULT_MODEL, seed=None)


@pytest.mark.parametrize("model_name", ["gpt-3.5-turbo", "gpt-4o-mini"])
Expand All @@ -22,7 +22,7 @@ def test_create_chat_openai_model_with_str(
):
coai_mock = mocker.patch("werewolf.utils.openai.ChatOpenAI")
create_chat_openai_model(model_name)
coai_mock.assert_called_once_with(model=model_name)
coai_mock.assert_called_once_with(model=model_name, seed=None)


def test_create_chat_openai_model_with_instance(mocker: MockerFixture):
Expand All @@ -31,6 +31,29 @@ def test_create_chat_openai_model_with_instance(mocker: MockerFixture):
assert actual is llm


@pytest.mark.parametrize(
'llm, seed',
[
('gpt-3.5-turbo', 123),
('gpt-4o-mini', 456),
('gpt-4o-mini', None),
],
)
def test_create_chat_openai_model_return_same_instance_for_same_input(
llm: str,
seed: int,
mocker: MockerFixture,
):
ChatOpenAI_mock = mocker.patch(
"werewolf.utils.openai.ChatOpenAI",
return_value=mocker.MagicMock(spec=ChatOpenAI),
)
actual1 = create_chat_openai_model(llm, seed)
actual2 = create_chat_openai_model(llm, seed)
assert actual1 is actual2
ChatOpenAI_mock.assert_called_once_with(model=llm, seed=seed)


@pytest.mark.integration
@pytest.mark.skipif(
os.getenv("OPENAI_API_KEY") is None,
Expand Down
6 changes: 6 additions & 0 deletions werewolf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
@click.option('-m', '--model', default=DEFAULT_MODEL, help=f'The model name. Default is {DEFAULT_MODEL}.') # noqa
@click.option('-p', '--printer', default='click.echo', help=f'The printer name. The valid values is in {KEYS_FOR_PRINTER}. Default is click.echo.') # noqa
@click.option('--sub-model', default=DEFAULT_MODEL, help=f'The sub-model name. Default is {DEFAULT_MODEL}.') # noqa
@click.option('--seed', default=None, help='The random seed.') # noqa
@click.option('--log-level', default='WARNING', help='The log level, DEBUG, INFO, WARNING, ERROR or CRITICAL. Default is WARNING.') # noqa
@click.option('--debug', is_flag=True, help='Whether to show debug logs or not.') # noqa
def main(
Expand All @@ -38,9 +39,14 @@ def main(
model: str,
printer: str,
sub_model: str,
seed: int | None,
log_level: str,
debug: bool,
):
if seed is not None:
import random
random.seed(seed)

load_dotenv(),
if os.environ.get('OPENAI_API_KEY') is None:
raise ValueError('You must set OPENAI_API_KEY in your environment variables or .env file.') # noqa
Expand Down
10 changes: 7 additions & 3 deletions werewolf/utils/openai.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
from functools import lru_cache
from langchain_openai import ChatOpenAI
from ..const import DEFAULT_MODEL


@lru_cache(maxsize=None)
def create_chat_openai_model(
llm: ChatOpenAI | str | None = None,
seed: int | None = None,
) -> ChatOpenAI:
"""Create a ChatOpenAI instance.
Args:
llm (ChatOpenAI | str | None, optional): ChatOpenAI instance or model name. Defaults to None.
seed (int, optional): Random seed. Defaults to None.
Returns:
ChatOpenAI: ChatOpenAI instance
Note:
seed is used only when llm is a str or None.
""" # noqa
if isinstance(llm, str):
return ChatOpenAI(model=llm)
return ChatOpenAI(model=llm, seed=seed)
else:
return llm or ChatOpenAI(model=DEFAULT_MODEL)
return llm or ChatOpenAI(model=DEFAULT_MODEL, seed=seed)

0 comments on commit 368f652

Please sign in to comment.