Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update printer functions and game configuration #8

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tests/utils/test_printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from werewolf.utils.printer import (
_print_dict,
create_print_func,
)


@pytest.mark.parametrize(
'key,expected',
list(_print_dict.items()),
)
def test_create_print_func(key, expected):
assert create_print_func(key) is expected


def test_create_print_func_invalid_key():
with pytest.raises(ValueError):
create_print_func('invalid_key')
12 changes: 7 additions & 5 deletions werewolf/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Iterable

import autogen
import click
from langchain_openai import ChatOpenAI

from .const import DEFAULT_MODEL, EGameMaster
from .config import GameConfig
from .game_master.base import BaseGameMaster
from .utils.openai import create_chat_openai_model
from .utils.printer import create_print_func


def game(
Expand All @@ -23,10 +23,12 @@ def game(
open_game: bool = False,
config_list=[{'model': DEFAULT_MODEL}],
llm: ChatOpenAI | str | None = None,
printer: str = 'click.echo',
log_file: str = 'werewolf.log',
logger: logging.Logger = logging.getLogger(__name__),
):
# preparation
print_func = create_print_func(printer)
llm = create_chat_openai_model(llm)
master = BaseGameMaster.instantiate(
EGameMaster.Default, # TODO
Expand Down Expand Up @@ -62,7 +64,7 @@ def game(
days = range(len(master.alive_players))
for _ in days:
# announce day
click.echo(f'=============================== Day {master.day} (Daytime) ================================') # noqa
print_func(f'=============================== Day {master.day} (Daytime) ================================') # noqa
master.announce(
'\n'.join([
f'Day {master.day}: Daytime.',
Expand All @@ -77,7 +79,7 @@ def game(
votes = master.daytime_discussion()
# exclude from the game
excluded_result = master.exclude_players_following_votes(votes)
click.echo('\n'.join([
print_func('\n'.join([
'============================== Excluded result ==============================', # noqa
str(excluded_result),
'=============================================================================', # noqa
Expand All @@ -89,7 +91,7 @@ def game(
break
# announce day
# announce night
click.echo(f'================================ Day {master.day} (Nighttime) ================================') # noqa
print_func(f'================================ Day {master.day} (Nighttime) ================================') # noqa
master.announce(
'\n'.join([
f'Day {master.day}: Nighttime.',
Expand All @@ -104,7 +106,7 @@ def game(
votes = master.nighttime_action()
# exclude from the game
excluded_result = master.exclude_players_following_votes(votes, announce_votes=False) # noqa
click.echo('\n'.join([
print_func('\n'.join([
'============================== Excluded result ==============================', # noqa
str(excluded_result),
'=============================================================================', # noqa
Expand Down
13 changes: 11 additions & 2 deletions werewolf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .const import DEFAULT_MODEL, ESpeakerSelectionMethod
from .game import game
from .utils.printer import KEYS_FOR_PRINTER, create_print_func


@click.command()
Expand All @@ -20,6 +21,7 @@
@click.option('-o', '--open-game', is_flag=True, help='Whether to open game or not.') # noqa
@click.option('-l', '--log', default=None, help='The log file name. Default is werewolf%Y%m%d%H%M%S.log') # noqa
@click.option('-m', '--model', default=DEFAULT_MODEL, help=f'The model name. Default is {DEFAULT_MODEL}.') # noqa
@click.option('-p', '--printer', default='click.echo', help=f'The printer name. The valid values is in {KEYS_FOR_PRINTER}. Default is click.echo.') # noqa
@click.option('--sub-model', default=DEFAULT_MODEL, help=f'The sub-model name. Default is {DEFAULT_MODEL}.') # noqa
@click.option('--log-level', default='WARNING', help='The log level, DEBUG, INFO, WARNING, ERROR or CRITICAL. Default is WARNING.') # noqa
@click.option('--debug', is_flag=True, help='Whether to show debug logs or not.') # noqa
Expand All @@ -34,16 +36,21 @@ def main(
open_game: bool,
log: str,
model: str,
printer: str,
sub_model: str,
log_level: str,
debug: bool,
):
load_dotenv(),
if os.environ.get('OPENAI_API_KEY') is None:
raise ValueError('You must set OPENAI_API_KEY in your environment variables or .env file.') # noqa

printer_func = create_print_func(printer)

log = f'werewolf{dt.now().strftime("%Y%m%d%H%M%S")}.log' if log is None else log # noqa
logging.basicConfig(level=logging.DEBUG if debug else getattr(logging, log_level.upper(), 'WARNING')) # type: ignore # noqa
config_list = [{'model': model}]

result = game(
n_players=n_players,
n_werewolves=n_werewolves,
Expand All @@ -53,12 +60,14 @@ def main(
speaker_selection_method=speaker_selection_method.value,
include_human=include_human,
open_game=open_game,
printer=printer,
log_file=log,
config_list=config_list,
llm=sub_model,
)
click.echo('================================ Game Result ================================') # noqa
click.echo(result)

printer_func('================================ Game Result ================================') # noqa
printer_func(result)


if __name__ == "__main__":
Expand Down
31 changes: 31 additions & 0 deletions werewolf/utils/printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from functools import partial
import logging
from typing import Callable
import click


_print_dict: dict[str, Callable[..., None]] = {
'print': print,
'click.echo': click.echo,
'logging.info': logging.info,
}

KEYS_FOR_PRINTER: tuple[str, ...] = tuple(_print_dict.keys())


def create_print_func(key: str, **kwargs) -> Callable[..., None]:
"""Create a print function.

Args:
key (str): key of the print function
**kwargs: keyword arguments to pass to the print function

Returns:
Callable[..., None]: print function
"""
try:
if kwargs:
return partial(_print_dict[key], **kwargs)
return _print_dict[key]
except KeyError:
raise ValueError(f'Invalid key: {key}. Valid keys are {list(KEYS_FOR_PRINTER)}') # noqa
Loading